模型训练、图片生成和模型的保存:
with tf.Session(config=config) as sess:for d in ['/gpu:0']:with tf.device(d):ckpt = tf.train.get_checkpoint_state('./models/')if ckpt and tf.train.checkpoint_exists(ckpt.model_checkpoint_path):saver.restore(sess, ckpt.model_checkpoint_path)print('Import models successful!')else:sess.run(tf.global_variables_initializer)print('Initialize successful!')for i in range(epoch):random.shuffle(img_label_org)random.shuffle(label_trg)for j in range(n_batch):if j == n_batch - 1:n = total_sample_numelse:n = j * batch_size + batch_sizeimg_org_output, img_trg_output, label_org_output, label_trg_output, image_name_output = reader.images_read(img_label_org[j*batch_size:n], label_trg[j*batch_size:n], img, imagedir)feeds = {org_image:img_org_output, trg_image:img_trg_output, org_pose:label_org_output,trg_pose:label_trg_output}if i < 400:sess.run(train_disc, feed_dict=feeds)sess.run(train_gen, feed_dict=feeds)sess.run(train_out, feed_dict=feeds)else:sess.run(train_gen, feed_dict=feeds)sess.run(train_out, feed_dict=feeds)if j%10==0:sess.run(train_disc, feed_dict=feeds)if j%2==0:gen_g_loss_,out_g_loss_, disc_loss_, org_image_, gen_trg_, out_trg_, trg_image_ = sess.run([gen_g_loss, out_g_loss, disc_loss, org_image, gen_trg, out_trg, trg_image],feeds)print("epoch:", i, "iter:", j, "gen_g_loss_:", gen_g_loss_, "out_g_loss_:", out_g_loss_, "loss_disc:", disc_loss_)for n in range(batch_size):org_image_output = (org_image_[n] + 1)*127.5gen_trg_output = (gen_trg_[n] + 1)*127.5out_trg_output = (out_trg_[n] + 1)*127.5trg_image_output = (trg_image_[n] + 1)*127.5temp = np.concatenate([org_image_output, gen_trg_output, out_trg_output, trg_image_output], 1)cv.imwrite("./record/%d_%d_%d_image.jpg" %(i, j, n), temp)if i%10==0 or i==epoch-1:saver.save(sess, './models/wssGAN.ckpt', global_step=gen_global_step)print("Finish!")
最终运行程序结果如下:
初始训练一次结果:
文章插图
训练20次结果:
文章插图
经过对比,可以发现有明显的提升!
源码地址:
https://pan.baidu.com/s/1cpRJlk7yUwhYJSIkRpbNpg
提取码:kdxe
作者介绍:
李秋键,CSDN 博客专家,CSDN达人课作者 。硕士在读于中国矿业大学,开发有taptapAndroid/ target=_blank class=infotextkey>安卓武侠游戏一部,vip视频解析,文意转换工具,写作机器人等项目,发表论文若干,多次高数竞赛获奖等等 。
推荐阅读
- SpringBoot常用属性配置
- 2020年适用于任何团队的5大数据库文档工具
- C,Java和Python之间的性能比较
- Linux操作系统中常用调度算法
- Python通过MySQLdb访问操作MySQL数据库
- Bash技巧:介绍一个可以增删改查键值对格式配置文件的Shell脚本
- 恢复AD用户误删,给你3种方案
- 使用 Mailmerge 发送定制邮件
- 在Python中使用Torchmoji将文本转换为表情符号
- 适合数据库初级人员 常用的sql语句集合