# Visual sanity check: draw one random example from the first test batch and
# show it next to the same image with its segmentation mask overlaid.
images, masks = next(iter(test_ds))

# Sample from the batch's actual length rather than BATCH_SIZE: a batch can be
# smaller than BATCH_SIZE (e.g. a dataset with fewer elements than one batch
# when the remainder is not dropped), and indexing past its end would fail.
random_idx = tf.random.uniform(
    [], minval=0, maxval=tf.shape(images)[0], dtype=tf.int32
)
test_image = images[random_idx].numpy().astype("float")
test_mask = masks[random_idx].numpy().astype("float")

# imshow only accepts (M, N), (M, N, 3) or (M, N, 4) arrays; drop a trailing
# singleton channel if the mask has one (2-D masks pass through unchanged).
if test_mask.ndim == 3 and test_mask.shape[-1] == 1:
    test_mask = test_mask[..., 0]

# NOTE(review): dividing by 255 assumes raw pixel values in [0, 255]; confirm
# against whatever preprocessing is applied to test_ds.
display_image = test_image / 255.0

fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(10, 5))
ax[0].set_title("Image")
ax[0].imshow(display_image)
ax[1].set_title("Image with segmentation mask overlay")
ax[1].imshow(display_image)
# Semi-transparent colormapped mask drawn on top of the image in the same axes.
ax[1].imshow(test_mask, cmap="inferno", alpha=0.6)
plt.show()