Parameter configurations come first; the figures follow.

Four versions varying the sampling temperature

# CLIP scores below 0.2 are clipped to a near-zero reward; the rest are scaled by 20,
# then averaged over the images generated for each prompt.
clip_scores = torch.where(clip_scores < 0.2, 1e-4, 20 * clip_scores).reshape(-1, self.num_images_per_prompt).mean(1)
diff_aes_scores = diff_aes_scores * 100
# Reward: log of the KL-penalized sum of CLIP score and aesthetic gain, clamped to stay positive.
R = torch.log(torch.clamp(-0.2 * kls + clip_scores + diff_aes_scores, min=0.001)).to(accelerator.device)

On top of this base, the temperature is varied over 0.7, 0.5, 0.3, 0.1.
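For reference, a minimal sketch of how this temperature sweep might be driven; `policy_model`, `input_ids`, and the `generate` arguments here are assumptions about the prompt-rewriting model, not the actual training script:

# Hypothetical sweep over sampling temperatures for the prompt-generation step.
for temperature in [0.7, 0.5, 0.3, 0.1]:
    prompt_ids = policy_model.generate(
        input_ids,
        do_sample=True,
        temperature=temperature,  # lower temperature -> sharper, less exploratory sampling
        max_new_tokens=77,
    )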

elif self.args.reward_version == 'v14':  # v3
    # Aesthetic gain of the rewritten prompt over the plain prompt, averaged per prompt and scaled by 100.
    diff_aes_scores = (aes_scores - aes_scores_plain).reshape(-1, self.num_images_per_prompt).mean(1)
    diff_aes_scores = diff_aes_scores * 100
    original_clip_scores = clip_scores.reshape(-1, self.num_images_per_prompt).mean(1)
    # CLIP scores below 0.2 are routed through penalty_function; the rest are scaled by 100.
    clip_scores = torch.where(clip_scores < 0.2, penalty_function(clip_scores), 100 * clip_scores).reshape(-1, self.num_images_per_prompt).mean(1)
    final_scores = 0
    return final_scores, original_clip_scores, clip_scores, diff_aes_scores, aes_scores, original_clip_scores_plain

The next 12 versions build on v14 and adjust the weighting. Yesterday we found the aes score can be pushed up easily, so today's main direction is to tweak the weighting and magic numbers to put more emphasis on the clip score.
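All of the variants below share one template; only the low-CLIP threshold, the penalty applied when it is not met, and the two multipliers change. A parameterized summary of that template (a sketch; the function and argument names are illustrative, not the actual code):

import torch

def reward_components(clip_scores, aes_scores, aes_scores_plain, num_images_per_prompt,
                      clip_thresh=0.2, low_penalty=-100.0, clip_weight=20.0, aes_weight=50.0):
    # Aesthetic gain of the rewritten prompt over the plain prompt, averaged per prompt.
    diff_aes = (aes_scores - aes_scores_plain).reshape(-1, num_images_per_prompt).mean(1) * aes_weight
    original_clip = clip_scores.reshape(-1, num_images_per_prompt).mean(1)
    # CLIP scores below the threshold get the penalty, the rest are scaled up.
    clip = torch.where(clip_scores < clip_thresh, low_penalty,
                       clip_weight * clip_scores).reshape(-1, num_images_per_prompt).mean(1)
    return original_clip, clip, diff_aes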

elif self.args.reward_version == 'v15':  # v3
    diff_aes_scores = (aes_scores - aes_scores_plain).reshape(-1, self.num_images_per_prompt).mean(1)
    diff_aes_scores = diff_aes_scores  # no extra scaling on the aesthetic gain in this version
    original_clip_scores = clip_scores.reshape(-1, self.num_images_per_prompt).mean(1)
    clip_scores = torch.where(clip_scores < 0.2, penalty_function(clip_scores), 100 * clip_scores).reshape(-1, self.num_images_per_prompt).mean(1)
    final_scores = 0
    return final_scores, original_clip_scores, clip_scores, diff_aes_scores, aes_scores, original_clip_scores_plain

elif self.args.reward_version == 'v16':  # v3
    diff_aes_scores = (aes_scores - aes_scores_plain).reshape(-1, self.num_images_per_prompt).mean(1)
    diff_aes_scores = diff_aes_scores  # no extra scaling on the aesthetic gain in this version
    original_clip_scores = clip_scores.reshape(-1, self.num_images_per_prompt).mean(1)
    clip_scores = torch.where(clip_scores < 0.2, penalty_function(clip_scores), 20 * clip_scores).reshape(-1, self.num_images_per_prompt).mean(1)
    final_scores = 0
    return final_scores, original_clip_scores, clip_scores, diff_aes_scores, aes_scores, original_clip_scores_plain

elif self.args.reward_version == 'v19': #v2
    diff_aes_scores = (aes_scores - aes_scores_plain).reshape(-1, self.num_images_per_prompt).mean(1)
    diff_aes_scores = diff_aes_scores * 100
    original_clip_scores = clip_scores.reshape(-1, self.num_images_per_prompt).mean(1)
    clip_scores = torch.where(clip_scores < 0.2, -20, 20 * clip_scores).reshape(-1, self.num_images_per_prompt).mean(1)
    final_scores = 0
    return final_scores, original_clip_scores, clip_scores, diff_aes_scores, aes_scores, original_clip_scores_plain

elif self.args.reward_version == 'v20': #v2
    diff_aes_scores = (aes_scores - aes_scores_plain).reshape(-1, self.num_images_per_prompt).mean(1)
    diff_aes_scores = diff_aes_scores * 100
    original_clip_scores = clip_scores.reshape(-1, self.num_images_per_prompt).mean(1)
    clip_scores = torch.where(clip_scores < 0.2, -100, 20 * clip_scores).reshape(-1, self.num_images_per_prompt).mean(1)
    final_scores = 0
    return final_scores, original_clip_scores, clip_scores, diff_aes_scores, aes_scores, original_clip_scores_plain

elif self.args.reward_version == 'v21':  # v2
    diff_aes_scores = (aes_scores - aes_scores_plain).reshape(-1, self.num_images_per_prompt).mean(1)
    diff_aes_scores = diff_aes_scores * 50
    original_clip_scores = clip_scores.reshape(-1, self.num_images_per_prompt).mean(1)
    clip_scores = torch.where(clip_scores < 0.2, -100, 20 * clip_scores).reshape(-1, self.num_images_per_prompt).mean(1)
    final_scores = 0
    return final_scores, original_clip_scores, clip_scores, diff_aes_scores, aes_scores, original_clip_scores_plain

elif self.args.reward_version == 'v22':  # v2
    diff_aes_scores = (aes_scores - aes_scores_plain).reshape(-1, self.num_images_per_prompt).mean(1)
    diff_aes_scores = diff_aes_scores * 5
    original_clip_scores = clip_scores.reshape(-1, self.num_images_per_prompt).mean(1)
    clip_scores = torch.where(clip_scores < 0.2, -100, 2 * clip_scores).reshape(-1, self.num_images_per_prompt).mean(1)
    final_scores = 0
    return final_scores, original_clip_scores, clip_scores, diff_aes_scores, aes_scores, original_clip_scores_plain

elif self.args.reward_version == 'v23':  # v2
    diff_aes_scores = (aes_scores - aes_scores_plain).reshape(-1, self.num_images_per_prompt).mean(1)
    diff_aes_scores = diff_aes_scores * 50
    original_clip_scores = clip_scores.reshape(-1, self.num_images_per_prompt).mean(1)
    clip_scores = torch.where(clip_scores < 0.2, -100, 100 * clip_scores).reshape(-1, self.num_images_per_prompt).mean(1)
    final_scores = 0
    return final_scores, original_clip_scores, clip_scores, diff_aes_scores, aes_scores, original_clip_scores_plain

elif self.args.reward_version == 'v25':  # v2
    diff_aes_scores = (aes_scores - aes_scores_plain).reshape(-1, self.num_images_per_prompt).mean(1)
    diff_aes_scores = diff_aes_scores * 75
    original_clip_scores = clip_scores.reshape(-1, self.num_images_per_prompt).mean(1)
    clip_scores = torch.where(clip_scores < 0.2, -100, 100 * clip_scores).reshape(-1, self.num_images_per_prompt).mean(1)
    final_scores = 0
    return final_scores, original_clip_scores, clip_scores, diff_aes_scores, aes_scores, original_clip_scores_plain

elif self.args.reward_version == 'v27':  # v2
    diff_aes_scores = (aes_scores - aes_scores_plain).reshape(-1, self.num_images_per_prompt).mean(1)
    diff_aes_scores = diff_aes_scores * 20
    original_clip_scores = clip_scores.reshape(-1, self.num_images_per_prompt).mean(1)
    clip_scores = torch.where(clip_scores < 0.15, -100, 20 * clip_scores).reshape(-1, self.num_images_per_prompt).mean(1)
    final_scores = 0
    return final_scores, original_clip_scores, clip_scores, diff_aes_scores, aes_scores, original_clip_scores_plain

elif self.args.reward_version == 'v28':  # v2
    diff_aes_scores = (aes_scores - aes_scores_plain).reshape(-1, self.num_images_per_prompt).mean(1)
    diff_aes_scores = diff_aes_scores * 50
    original_clip_scores = clip_scores.reshape(-1, self.num_images_per_prompt).mean(1)
    clip_scores = torch.where(clip_scores < 0.15, -100, 20 * clip_scores).reshape(-1, self.num_images_per_prompt).mean(1)
    final_scores = 0
    return final_scores, original_clip_scores, clip_scores, diff_aes_scores, aes_scores, original_clip_scores_plain


v29 and v30 use num updates = 8.
elif self.args.reward_version == 'v29':  # v2
    diff_aes_scores = (aes_scores - aes_scores_plain).reshape(-1, self.num_images_per_prompt).mean(1)
    diff_aes_scores = diff_aes_scores * 100
    original_clip_scores = clip_scores.reshape(-1, self.num_images_per_prompt).mean(1)
    clip_scores = torch.where(clip_scores < 0.15, 1e-4, 100 * clip_scores).reshape(-1, self.num_images_per_prompt).mean(1)
    final_scores = 0
    return final_scores, original_clip_scores, clip_scores, diff_aes_scores, aes_scores, original_clip_scores_plain

elif self.args.reward_version == 'v30':  # v2
    diff_aes_scores = (aes_scores - aes_scores_plain).reshape(-1, self.num_images_per_prompt).mean(1)
    diff_aes_scores = diff_aes_scores * 100
    original_clip_scores = clip_scores.reshape(-1, self.num_images_per_prompt).mean(1)
    clip_scores = torch.where(clip_scores < 0.15, 1e-4, 20 * clip_scores).reshape(-1, self.num_images_per_prompt).mean(1)
    final_scores = 0
    return final_scores, original_clip_scores, clip_scores, diff_aes_scores, aes_scores, original_clip_scores_plain

No result stood out over yesterday's; the intermediate changes in the model's outputs were recorded.

Epsilon-greedy exploration hasn't been done yet (the idea: when sampling the next token, occasionally pick a random token from the vocabulary, or pick one from the ref model); see the sketch below.
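A rough sketch of what that per-token exploration could look like (entirely hypothetical; `policy_logits`, `ref_logits`, and the mixing probabilities are made-up names, nothing here exists in the codebase yet):

import torch

def explore_next_token(policy_logits, ref_logits, vocab_size, eps_random=0.05, eps_ref=0.05):
    # With prob. eps_random pick a uniformly random token from the vocabulary,
    # with prob. eps_ref sample from the reference model's distribution instead,
    # otherwise sample from the current policy as usual.
    u = torch.rand(()).item()
    if u < eps_random:
        return torch.randint(vocab_size, (1,))
    if u < eps_random + eps_ref:
        return torch.multinomial(torch.softmax(ref_logits, dim=-1), 1)
    return torch.multinomial(torch.softmax(policy_logits, dim=-1), 1)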

When the aes score weight is too high, the model tends to output:

[image: sample output]

When the clip score weight is too high, the model tends to repeat the objects mentioned earlier, but without any modifiers:

[image: sample output]