import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import Dense, Flatten, Reshape, Input
from tensorflow.keras.datasets import mnist

# Load the MNIST training images only; labels and the test split are unused.
(X_train, _), (_, _) = mnist.load_data()
# Scale pixels from [0, 255] to [-1, 1] so real images share the range of the
# generator's tanh output. Cast to float32 first: dividing the raw array
# directly would produce a float64 copy (twice the memory), which Keras would
# only cast back to float32 batch by batch during training.
X_train = X_train.astype('float32') / 127.5 - 1.0
# Flatten each 28x28 image into a 784-vector for the Dense discriminator.
X_train = X_train.reshape(-1, 784)

# Generator: maps a 100-dim latent noise vector to a flattened 28x28 (784)
# image. The tanh output lies in [-1, 1], matching the scaling of X_train.
generator = Sequential([
  Input(shape=(100,)),
  Dense(128, activation='relu'),
  Dense(784, activation='tanh')
])

# Discriminator: maps a flattened 784-vector to one sigmoid score in (0, 1);
# the training loop labels real images 1 and generated images 0.
discriminator = Sequential([
  Input(shape=(784,)),
  Dense(128, activation='relu'),
  Dense(1, activation='sigmoid')
])

# Compile the discriminator on its own, while its weights are still
# trainable, for the standalone real/fake updates in the training loop.
discriminator.compile(optimizer='adam', loss='binary_crossentropy')

# Freeze the discriminator before building and compiling the stacked model,
# so that gan.train_on_batch is intended to update only the generator.
# NOTE(review): whether `trainable` is captured at compile() or when a model's
# train function is first built differs between Keras versions; if it is read
# lazily, discriminator.train_on_batch would also see it frozen unless the
# flag is reset before those calls — confirm against the installed Keras.
discriminator.trainable = False
gan_input = Input(shape=(100,))
fake = generator(gan_input)  # symbolic generator output, not a data batch
gan_output = discriminator(fake)

# Stacked model: noise -> generator -> discriminator. The training loop fits
# it against all-ones ("real") targets so the generator learns to fool D.
gan = Model(gan_input, gan_output)
gan.compile(optimizer='adam', loss='binary_crossentropy')

# Adversarial training. Each iteration is one optimizer step on a single
# random mini-batch (not a full pass over the data), hence `step`.
batch_size = 32   # samples per mini-batch
latent_dim = 100  # must match the generator's Input(shape=(100,))
# Targets are constant, so build them once instead of on every iteration.
real_labels = np.ones((batch_size, 1))
fake_labels = np.zeros((batch_size, 1))

for step in range(2000):
  # Sample a random mini-batch of real images (with replacement).
  idx = np.random.randint(0, X_train.shape[0], batch_size)
  real_batch = X_train[idx]
  # Generate a matching batch of fake images from Gaussian noise.
  noise = np.random.normal(0, 1, (batch_size, latent_dim))
  fake_batch = generator.predict(noise, verbose=0)

  # Discriminator update: real -> 1, fake -> 0.
  # Some Keras versions read `trainable` when a model's train function is
  # first built rather than at compile(); setting it explicitly before each
  # phase keeps both updates correct regardless of which behaviour applies.
  discriminator.trainable = True
  discriminator.train_on_batch(real_batch, real_labels)
  discriminator.train_on_batch(fake_batch, fake_labels)

  # Generator update through the stacked model with D frozen: push D to
  # label the generator's output as real.
  discriminator.trainable = False
  gan.train_on_batch(noise, real_labels)

# Draw a few samples from the trained generator and show them side by side.
n_samples = 5
latent_dim = 100  # must match the generator's Input(shape=(100,))
noise = np.random.normal(0, 1, (n_samples, latent_dim))
gen_imgs = generator.predict(noise, verbose=0)

# squeeze=False keeps `axes` 2-D even if n_samples is changed to 1.
_, axes = plt.subplots(1, n_samples, squeeze=False, figsize=(2 * n_samples, 2))
for ax, img in zip(axes[0], gen_imgs):
  # Generator output is a flattened 784-vector in [-1, 1] (tanh); pin the
  # colour range so each image is not independently auto-rescaled.
  ax.imshow(img.reshape(28, 28), cmap='gray', vmin=-1, vmax=1)
  ax.axis('off')
# Display once, after every panel has been drawn.
plt.show()