共迭代了 14 个实验,各版本的奖励(R)设置记录如下:
version1
diff_aes_scores 乘以 10
clip score 低于 0.15 时给予高惩罚(置为 1e-4);以 0.22 为上限,超过后奖励固定为 4.4(因为之前发现 clip score 在 0.25 左右时结果几乎和原文差不多,继续奖励更高的 clip score 不会鼓励 GFN 创新)
clip_scores = torch.where(clip_scores < 0.15, 1e-4, 20 * clip_scores).reshape(-1, self.num_images_per_prompt).mean(1)
clip_scores = torch.where(clip_scores > 0.22, 4.4, 20 * clip_scores)
diff_aes_scores = diff_aes_scores * 10
R = torch.clamp(torch.log(torch.clamp(diff_aes_scores, min=0.001)), min=-20) + torch.clamp(torch.log(torch.clamp(clip_scores, min=1e-4)), min=-100)
R = -0.01 * kls + R
version2
clip_scores = torch.where(clip_scores < 0.15, 1e-4, 20 * clip_scores).reshape(-1, self.num_images_per_prompt).mean(1)
clip_scores = torch.where(clip_scores > 0.28, 5.6, 20 * clip_scores)
diff_aes_scores = diff_aes_scores * 100
R = torch.clamp(torch.log(torch.clamp(diff_aes_scores, min=0.001)), min=-20) + torch.clamp(torch.log(torch.clamp(clip_scores, min=1e-4)), min=-100)
R = -0.01 * kls + R
version3
clip_scores = torch.where(clip_scores < 0.15, 1e-4, 20 * clip_scores).reshape(-1, self.num_images_per_prompt).mean(1)
clip_scores = torch.where(clip_scores > 0.22, 4.4, 20 * clip_scores)
diff_aes_scores = diff_aes_scores * 100
R = torch.clamp(torch.log(torch.clamp(diff_aes_scores, min=0.001)), min=-20) + torch.clamp(torch.log(torch.clamp(clip_scores, min=1e-4)), min=-100)
R = -0.01 * kls + R
version4
clip_scores = torch.where(clip_scores < 0.15, 1e-4, 20 * clip_scores).reshape(-1, self.num_images_per_prompt).mean(1)
clip_scores = torch.where(clip_scores > 0.28, 5.6, 20 * clip_scores)
diff_aes_scores = diff_aes_scores * 10
R = torch.clamp(torch.log(torch.clamp(diff_aes_scores, min=0.001)), min=-20) + torch.clamp(torch.log(torch.clamp(clip_scores, min=1e-4)), min=-100)
R = -0.01 * kls + R
version5
把 KL 项(系数 -0.01)、clip score 和 diff aes score 相加后整体看作 R,截断到不小于 0.001 后统一取 log(不再对各项分别取 log)
clip_scores = torch.where(clip_scores < 0.15, 1e-4, 20 * clip_scores).reshape(-1, self.num_images_per_prompt).mean(1)
clip_scores = torch.where(clip_scores > 0.22, 4.4, 20 * clip_scores)
diff_aes_scores = diff_aes_scores * 10
R = torch.log(torch.clamp(-0.01 * kls + clip_scores + diff_aes_scores, min = 0.001)).to(accelerator.device)
version6
clip_scores = torch.where(clip_scores < 0.15, 1e-4, 20 * clip_scores).reshape(-1, self.num_images_per_prompt).mean(1)
clip_scores = torch.where(clip_scores > 0.28, 5.6, 20 * clip_scores)
diff_aes_scores = diff_aes_scores * 100
R = torch.log(torch.clamp(-0.01 * kls + clip_scores + diff_aes_scores, min = 0.001)).to(accelerator.device)
version7
clip_scores = torch.where(clip_scores < 0.15, 1e-4, 20 * clip_scores).reshape(-1, self.num_images_per_prompt).mean(1)
clip_scores = torch.where(clip_scores > 0.22, 4.4, 20 * clip_scores)
diff_aes_scores = diff_aes_scores * 100
R = torch.log(torch.clamp(-0.01 * kls + clip_scores + diff_aes_scores, min = 0.001)).to(accelerator.device)