5分钟入门GANs:原理解释和Keras代码实现(二)


# Combined networkdiscriminator.trainable = FalseganInput = Input(shape=(randomDim,))x = generator(ganInput)ganOutput = discriminator(x)gan = Model(inputs=ganInput, outputs=ganOutput)gan.compile(loss='binary_crossentropy', optimizer=adam)dLosses = []gLosses = []三个函数,每20个epoch绘制并保存结果,并保存模型 。
# Plot the recorded losses
def plotLoss(epoch):
    """Plot the recorded discriminator and generator losses and save a PNG.

    Reads the module-level ``dLosses`` and ``gLosses`` lists and writes
    ``images/gan_loss_epoch_<epoch>.png``. The ``images/`` directory is not
    created here.

    :param epoch: epoch number; used only in the output file name.
    """
    plt.figure(figsize=(10, 8))
    plt.plot(dLosses, label='Discriminative loss')
    plt.plot(gLosses, label='Generative loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.savefig('images/gan_loss_epoch_%d.png' % epoch)
    # Release the figure; otherwise every call leaves one open figure behind.
    plt.close()


# Create a wall of generated MNIST images
def plotGeneratedImages(epoch, examples=100, dim=(10, 10), figsize=(10, 10)):
    """Sample images from the generator and save them as one grid PNG.

    Writes ``images/gan_generated_image_epoch_<epoch>.png``. The ``images/``
    directory is not created here.

    :param epoch: epoch number; used only in the output file name.
    :param examples: number of noise vectors to sample and draw.
    :param dim: (rows, cols) of the subplot grid; must hold ``examples`` cells.
    :param figsize: matplotlib figure size.
    :raises ValueError: if the ``dim`` grid has fewer cells than ``examples``.
    """
    if examples > dim[0] * dim[1]:
        raise ValueError('a %dx%d grid cannot hold %d images'
                         % (dim[0], dim[1], examples))
    noise = np.random.normal(0, 1, size=[examples, randomDim])
    generatedImages = generator.predict(noise)
    # Reshape the generator output into `examples` images of 28x28 pixels.
    generatedImages = generatedImages.reshape(examples, 28, 28)

    plt.figure(figsize=figsize)
    for i in range(generatedImages.shape[0]):
        plt.subplot(dim[0], dim[1], i + 1)
        plt.imshow(generatedImages[i], interpolation='nearest', cmap='gray_r')
        plt.axis('off')
    plt.tight_layout()
    plt.savefig('images/gan_generated_image_epoch_%d.png' % epoch)
    # Release the figure; otherwise every call leaves one open figure behind.
    plt.close()


# Save the generator and discriminator networks (and weights) for later use
def saveModels(epoch):
    """Save both networks as ``models/gan_<net>_epoch_<epoch>.h5`` files.

    The ``models/`` directory is not created here.

    :param epoch: epoch number; used only in the output file names.
    """
    generator.save('models/gan_generator_epoch_%d.h5' % epoch)
    discriminator.save('models/gan_discriminator_epoch_%d.h5' % epoch)
训练函数
def train(epochs=1, batchSize=128):
    """Alternately train the discriminator and the generator.

    Each epoch runs ``X_train.shape[0] // batchSize`` steps. In every step,
    the discriminator is trained on ``batchSize`` real images (label 0.9,
    one-sided label smoothing) and ``batchSize`` generated images (label 0).
    The generator is then trained through the stacked ``gan`` model with
    target label 1.

    After each epoch, the losses of that epoch's last batch are appended to
    ``dLosses`` / ``gLosses``. Sample images and model files are saved on
    epoch 1 and on every 20th epoch. The loss curve is saved once, after
    training.

    :param epochs: number of epochs to run.
    :param batchSize: number of real (and of generated) images per step.
    :raises ValueError: if ``X_train`` has fewer samples than ``batchSize``.
    """
    # Floor division: '/' gives a float on Python 3, which range() rejects.
    batchCount = X_train.shape[0] // batchSize
    if batchCount == 0:
        # Otherwise the inner loop never runs and dloss/gloss stay unbound.
        raise ValueError('X_train has fewer samples (%d) than batchSize (%d)'
                         % (X_train.shape[0], batchSize))
    # Single-argument print(...) prints the same text on Python 2 and 3.
    print('Epochs: %d' % epochs)
    print('Batch size: %d' % batchSize)
    print('Batches per epoch: %d' % batchCount)

    for e in range(1, epochs + 1):
        print('%s Epoch %d %s' % ('-' * 15, e, '-' * 15))
        for _ in tqdm(range(batchCount)):
            # Get a random set of input noise and images
            noise = np.random.normal(0, 1, size=[batchSize, randomDim])
            imageBatch = X_train[np.random.randint(0, X_train.shape[0], size=batchSize)]

            # Generate fake MNIST images
            generatedImages = generator.predict(noise)
            X = np.concatenate([imageBatch, generatedImages])

            # Labels for generated and real data: real first, then fake
            yDis = np.zeros(2 * batchSize)
            # One-sided label smoothing
            yDis[:batchSize] = 0.9

            # Train discriminator
            # NOTE(review): Keras applies `trainable` when a model is compiled,
            # so these toggles are likely no-ops; kept to mirror the original.
            discriminator.trainable = True
            dloss = discriminator.train_on_batch(X, yDis)

            # Train generator: ask the stacked model to call fakes "real" (1)
            noise = np.random.normal(0, 1, size=[batchSize, randomDim])
            yGen = np.ones(batchSize)
            discriminator.trainable = False
            gloss = gan.train_on_batch(noise, yGen)

        # Store loss of most recent batch from this epoch
        dLosses.append(dloss)
        gLosses.append(gloss)

        if e == 1 or e % 20 == 0:
            plotGeneratedImages(e)
            saveModels(e)

    # Plot losses from every epoch (equals the last `e`; safe if epochs == 0)
    plotLoss(epochs)
至此一个简单的GAN已经完成了,完整的代码在这里找到
https://github.com/bhaveshgoyal27/mediumblogs/blob/master/Keras


推荐阅读