Iterated over 14 experiments; among them:

Details

version1

diff (diff_aes_scores) * 10

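CLIP scores below 0.15 get a heavy penalty, and 0.22 is the upper limit (we had previously found that at around 0.25 the output is almost identical to the original, so a higher cap would not encourage the GFN to innovate).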

clip_scores = torch.where(clip_scores < 0.15, 1e-4, 20 * clip_scores).reshape(-1, self.num_images_per_prompt).mean(1)
clip_scores = torch.where(clip_scores > 0.22, 4.4, 20 * clip_scores)
diff_aes_scores = diff_aes_scores * 10
R = torch.clamp(torch.log(torch.clamp(diff_aes_scores, min=0.001)), min=-20) + torch.clamp(torch.log(torch.clamp(clip_scores, min=1e-4)), min=-100)
R = -0.01 * kls + R
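A minimal sketch of the CLIP shaping described above, as standalone functions. `shape_clip`, `shape_clip_as_logged` and their parameter names are hypothetical, and the sketch assumes the cap is meant to act on the x20-scaled per-prompt mean, i.e. at 20 x 0.22 = 4.4 (likewise 5.6 for the 0.28 versions). In the code above, the second `torch.where` runs after the x20 scaling and the per-prompt mean, so it compares an already scaled value against the raw threshold 0.22, and any value below it is multiplied by 20 a second time. As a result every prompt is set to 4.4 unless (nearly) all of its images score below 0.15. The second function reproduces that logged behaviour for comparison.

```python
import torch


def shape_clip(raw: torch.Tensor, num_images_per_prompt: int,
               low: float = 0.15, cap: float = 0.22,
               scale: float = 20.0) -> torch.Tensor:
    """Per-prompt CLIP term, capped at scale * cap (hypothetical sketch)."""
    # Images scoring below `low` get 1e-4, i.e. log(1e-4) ~ -9.2 in the reward.
    per_image = torch.where(raw < low, torch.full_like(raw, 1e-4), scale * raw)
    # Average over the images generated for each prompt.
    per_prompt = per_image.reshape(-1, num_images_per_prompt).mean(1)
    # Cap on the scaled value: 0.22 -> 4.4, 0.28 -> 5.6.
    return torch.clamp(per_prompt, max=scale * cap)


def shape_clip_as_logged(raw: torch.Tensor, num_images_per_prompt: int,
                         low: float = 0.15, cap: float = 0.22,
                         scale: float = 20.0) -> torch.Tensor:
    """The two torch.where steps exactly as logged, for comparison."""
    x = torch.where(raw < low, torch.full_like(raw, 1e-4), scale * raw)
    x = x.reshape(-1, num_images_per_prompt).mean(1)
    # Compares the already scaled mean against the raw threshold `cap`.
    return torch.where(x > cap, torch.full_like(x, scale * cap), scale * x)


raw = torch.tensor([0.10, 0.20, 0.30, 0.25])  # 2 prompts x 2 images
print(shape_clip(raw, 2))            # first prompt ~2.0, second capped at 4.4
print(shape_clip_as_logged(raw, 2))  # both prompts end up at 4.4
```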

version2

clip_scores = torch.where(clip_scores < 0.15, 1e-4, 20 * clip_scores).reshape(-1, self.num_images_per_prompt).mean(1)
clip_scores = torch.where(clip_scores > 0.28, 5.6, 20 * clip_scores)
diff_aes_scores = diff_aes_scores * 100
R = torch.clamp(torch.log(torch.clamp(diff_aes_scores, min=0.001)), min=-20) + torch.clamp(torch.log(torch.clamp(clip_scores, min=1e-4)), min=-100)
R = -0.01 * kls + R

version3

clip_scores = torch.where(clip_scores < 0.15, 1e-4, 20 * clip_scores).reshape(-1, self.num_images_per_prompt).mean(1)
clip_scores = torch.where(clip_scores > 0.22, 4.4, 20 * clip_scores)
diff_aes_scores = diff_aes_scores * 100
R = torch.clamp(torch.log(torch.clamp(diff_aes_scores, min=0.001)), min=-20) + torch.clamp(torch.log(torch.clamp(clip_scores, min=1e-4)), min=-100)
R = -0.01 * kls + R

version4

clip_scores = torch.where(clip_scores < 0.15, 1e-4, 20 * clip_scores).reshape(-1, self.num_images_per_prompt).mean(1)
clip_scores = torch.where(clip_scores > 0.28, 5.6, 20 * clip_scores)
diff_aes_scores = diff_aes_scores * 10
R = torch.clamp(torch.log(torch.clamp(diff_aes_scores, min=0.001)), min=-20) + torch.clamp(torch.log(torch.clamp(clip_scores, min=1e-4)), min=-100)
R = -0.01 * kls + R
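Versions 1-4 differ only in the CLIP cap (0.22 or 0.28) and the multiplier on diff_aes_scores (10 or 100). A sketch that makes those two knobs explicit, assuming the CLIP term has already been shaped per prompt and that all tensors are aligned per prompt; `reward_separate_log`, `kl_coef` and the version table are hypothetical names, and the sample inputs are made up for illustration.

```python
import torch

# (clip cap, diff_aes_scores multiplier) per version, as listed in this log.
SEPARATE_LOG_VERSIONS = {1: (0.22, 10), 2: (0.28, 100), 3: (0.22, 100), 4: (0.28, 10)}


def reward_separate_log(clip_term: torch.Tensor, diff_aes: torch.Tensor,
                        kls: torch.Tensor, diff_scale: float,
                        kl_coef: float = 0.01) -> torch.Tensor:
    """Versions 1-4: log each term separately, subtract the KL penalty outside."""
    aes = torch.clamp(torch.log(torch.clamp(diff_scale * diff_aes, min=1e-3)), min=-20)
    clip = torch.clamp(torch.log(torch.clamp(clip_term, min=1e-4)), min=-100)
    return -kl_coef * kls + aes + clip


# Same inputs under each version; only the cap and the diff multiplier change.
diff_aes = torch.tensor([0.05])
kls = torch.tensor([10.0])
for version, (cap, diff_scale) in SEPARATE_LOG_VERSIONS.items():
    clip_term = torch.tensor([20 * cap])  # a prompt sitting exactly at the cap
    print(version, reward_separate_log(clip_term, diff_aes, kls, diff_scale).item())
```

Going from x10 to x100 shifts the aesthetic term by log 10 (about 2.3) as long as the scaled diff stays above the 1e-3 floor, while raising the cap from 0.22 to 0.28 adds at most log(5.6 / 4.4) (about 0.24) to the CLIP term.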

version5

Treat the KL, CLIP score and diff aes score together as R, and take a single log of their combined sum.

clip_scores = torch.where(clip_scores < 0.15, 1e-4, 20 * clip_scores).reshape(-1, self.num_images_per_prompt).mean(1)
clip_scores = torch.where(clip_scores > 0.22, 4.4, 20 * clip_scores)
diff_aes_scores = diff_aes_scores * 10
R = torch.log(torch.clamp(-0.01 * kls + clip_scores + diff_aes_scores, min = 0.001)).to(accelerator.device)
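Version 5 moves the KL penalty inside a single log of the summed terms. A sketch contrasting it with the separate-log form of versions 1-4 on two prompts, one at the CLIP cap and one fully penalized; the function names, `kl_coef`, the version table and the sample values are hypothetical, and all tensors are assumed to be aligned per prompt.

```python
import torch

# (clip cap, diff_aes_scores multiplier) per version, as listed in this log.
JOINT_LOG_VERSIONS = {5: (0.22, 10), 6: (0.28, 100), 7: (0.22, 100)}


def reward_joint_log(clip_term, diff_aes, kls, diff_scale, kl_coef=0.01):
    """Versions 5-7: sum KL penalty, CLIP and aesthetic terms, then take one log."""
    total = -kl_coef * kls + clip_term + diff_scale * diff_aes
    return torch.log(torch.clamp(total, min=1e-3))


def reward_separate_log(clip_term, diff_aes, kls, diff_scale, kl_coef=0.01):
    """Versions 1-4, same clamps as in the log: one log per term, KL outside."""
    aes = torch.clamp(torch.log(torch.clamp(diff_scale * diff_aes, min=1e-3)), min=-20)
    clip = torch.clamp(torch.log(torch.clamp(clip_term, min=1e-4)), min=-100)
    return -kl_coef * kls + aes + clip


cap, diff_scale = JOINT_LOG_VERSIONS[5]
clip_term = torch.tensor([20 * cap, 1e-4])  # capped prompt, fully penalized prompt
diff_aes = torch.tensor([0.05, 0.05])
kls = torch.tensor([10.0, 10.0])
joint = reward_joint_log(clip_term, diff_aes, kls, diff_scale)
separate = reward_separate_log(clip_term, diff_aes, kls, diff_scale)
print(joint)     # about [1.57, -0.92]
print(separate)  # about [0.69, -10.00]
```

With the CLIP term floored, the joint form still tracks the aesthetic gain minus the KL penalty (log 0.4001 here), while the separate form adds log(1e-4), about -9.2, on top.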

version6

clip_scores = torch.where(clip_scores < 0.15, 1e-4, 20 * clip_scores).reshape(-1, self.num_images_per_prompt).mean(1)
clip_scores = torch.where(clip_scores > 0.28, 5.6, 20 * clip_scores)
diff_aes_scores = diff_aes_scores * 100
R = torch.log(torch.clamp(-0.01 * kls + clip_scores + diff_aes_scores, min = 0.001)).to(accelerator.device)

version7

clip_scores = torch.where(clip_scores < 0.15, 1e-4, 20 * clip_scores).reshape(-1, self.num_images_per_prompt).mean(1)
clip_scores = torch.where(clip_scores > 0.22, 4.4, 20 * clip_scores)
diff_aes_scores = diff_aes_scores * 100
R = torch.log(torch.clamp(-0.01 * kls + clip_scores + diff_aes_scores, min = 0.001)).to(accelerator.device)
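Written out, with notation introduced here (c: the per-prompt CLIP term after shaping, d: diff_aes_scores, s: its multiplier, 10 or 100, K: kls), the two reward forms in this log are as follows. Since log 10^{-3} is about -6.91 and log 10^{-4} is about -9.21, the outer clamps at -20 and -100 in versions 1-4 never bind, which gives the second line.

```latex
\begin{aligned}
R_{\text{sep}} &= \max\!\big(\log\max(s\,d,\,10^{-3}),\,-20\big)
                + \max\!\big(\log\max(c,\,10^{-4}),\,-100\big) - 0.01\,K
                && \text{(versions 1--4)}\\
               &= \log\!\big(\max(s\,d,\,10^{-3})\cdot\max(c,\,10^{-4})\big) - 0.01\,K
                && \text{(outer clamps never bind)}\\
R_{\text{joint}} &= \log\max\!\big(c + s\,d - 0.01\,K,\; 10^{-3}\big)
                && \text{(versions 5--7)}
\end{aligned}
```

So versions 1-4 score the product of the aesthetic and CLIP terms, while versions 5-7 score their sum, with the KL penalty moved inside the log.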