The vanilla VAE shows distinct clusters, whereas the CVAE has a more homogeneous distribution. The vanilla VAE must encode both class identity and within-class variation into the latent space, since there is no supplied conditional signal. The CVAE, however, does not need to learn class distinctions, so the latent space can focus on variation within classes. As a result, a CVAE can potentially capture more information, because it does not spend capacity learning the class conditioning itself.
Two model architectures were created to test image generation. The first architecture was a convolutional CVAE with a concatenating conditional approach. All networks were built for Fashion-MNIST images of size 28×28 (784 total pixels).
import torch
import torch.nn as nn

class ConcatConditionalVAE(nn.Module):
    def __init__(self, latent_dim=128, num_classes=10):
        super().__init__()
        self.latent_dim = latent_dim
        self.num_classes = num_classes
        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 32, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Flatten()
        )
        self.flatten_size = 128 * 4 * 4
        # Conditional embedding
        self.label_embedding = nn.Embedding(num_classes, 32)
        # Latent space (with concatenated condition)
        self.fc_mu = nn.Linear(self.flatten_size + 32, latent_dim)
        self.fc_var = nn.Linear(self.flatten_size + 32, latent_dim)
        # Decoder
        self.decoder_input = nn.Linear(latent_dim + 32, 4 * 4 * 128)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 2, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 1, 3, stride=2, padding=1, output_padding=1),
            nn.Sigmoid()
        )

    def encode(self, x, c):
        x = self.encoder(x)
        c = self.label_embedding(c)
        # Concatenate condition with encoded input
        x = torch.cat([x, c], dim=1)
        mu = self.fc_mu(x)
        log_var = self.fc_var(x)
        return mu, log_var

    def reparameterize(self, mu, log_var):
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z, c):
        c = self.label_embedding(c)
        # Concatenate condition with latent vector
        z = torch.cat([z, c], dim=1)
        z = self.decoder_input(z)
        z = z.view(-1, 128, 4, 4)
        return self.decoder(z)

    def forward(self, x, c):
        mu, log_var = self.encode(x, c)
        z = self.reparameterize(mu, log_var)
        return self.decode(z, c), mu, log_var
The CVAE encoder consists of three convolutional layers, each followed by a ReLU non-linearity, and the encoder output is then flattened. The class number is passed through an embedding layer and concatenated with the encoder output. The reparameterization trick is then used with two linear layers to obtain μ and σ in the latent space. Once sampled, the output of the reparameterized latent space is passed to the decoder, again concatenated with the class-number embedding output. The decoder consists of three transposed convolutional layers; the first two include a ReLU non-linearity, while the final layer uses a sigmoid non-linearity. The output of the decoder is a 28×28 generated image.
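To make the tensor shapes concrete, here is a minimal sanity check of the forward pass. This is a sketch, assuming the class above; the random batch of images and labels is purely illustrative.

# Quick shape check for the concatenating CVAE (illustrative only)
model = ConcatConditionalVAE(latent_dim=128, num_classes=10)
x = torch.rand(16, 1, 28, 28)        # fake batch of 16 grayscale 28x28 images
c = torch.randint(0, 10, (16,))      # fake class labels
recon, mu, log_var = model(x, c)
print(recon.shape, mu.shape, log_var.shape)
# torch.Size([16, 1, 28, 28]) torch.Size([16, 128]) torch.Size([16, 128])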
The other model architecture follows the same approach, but adds the conditional input instead of concatenating it. A main question was whether adding or concatenating would lead to better reconstruction or generation results.
class AdditiveConditionalVAE(nn.Module):
    def __init__(self, latent_dim=128, num_classes=10):
        super().__init__()
        self.latent_dim = latent_dim
        self.num_classes = num_classes
        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 32, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Flatten()
        )
        self.flatten_size = 128 * 4 * 4
        # Conditional embedding
        self.label_embedding = nn.Embedding(num_classes, self.flatten_size)
        # Latent space (without concatenation)
        self.fc_mu = nn.Linear(self.flatten_size, latent_dim)
        self.fc_var = nn.Linear(self.flatten_size, latent_dim)
        # Decoder condition embedding
        self.decoder_label_embedding = nn.Embedding(num_classes, latent_dim)
        # Decoder
        self.decoder_input = nn.Linear(latent_dim, 4 * 4 * 128)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 2, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 1, 3, stride=2, padding=1, output_padding=1),
            nn.Sigmoid()
        )

    def encode(self, x, c):
        x = self.encoder(x)
        c = self.label_embedding(c)
        # Add condition to encoded input
        x = x + c
        mu = self.fc_mu(x)
        log_var = self.fc_var(x)
        return mu, log_var

    def reparameterize(self, mu, log_var):
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z, c):
        # Add condition to latent vector
        c = self.decoder_label_embedding(c)
        z = z + c
        z = self.decoder_input(z)
        z = z.view(-1, 128, 4, 4)
        return self.decoder(z)

    def forward(self, x, c):
        mu, log_var = self.encode(x, c)
        z = self.reparameterize(mu, log_var)
        return self.decode(z, c), mu, log_var
The same loss function, from the equation shown above, is used for all CVAEs.
import torch.nn.functional as F

def loss_function(recon_x, x, mu, logvar):
    """Computes the loss = -ELBO = Negative Log-Likelihood + KL Divergence.
    Args:
        recon_x: Decoder output.
        x: Ground truth.
        mu: Mean of Z.
        logvar: Log-Variance of Z.
    """
    BCE = F.binary_cross_entropy(recon_x, x, reduction="sum")
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD
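For context, here is a minimal sketch of how this loss would be used in a training loop. It assumes a PyTorch DataLoader named train_loader yielding (image, label) batches (such as the Fashion-MNIST loader sketched below) and uses the concatenating model defined above; the optimizer, learning rate, and epoch count are illustrative, not the exact settings used.

# Minimal training-loop sketch (hypothetical hyperparameters and data loader)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ConcatConditionalVAE(latent_dim=128, num_classes=10).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(50):
    model.train()
    total_loss = 0.0
    for x, c in train_loader:          # assumed Fashion-MNIST loader
        x, c = x.to(device), c.to(device)
        optimizer.zero_grad()
        recon, mu, log_var = model(x, c)
        loss = loss_function(recon, x, mu, log_var)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"epoch {epoch}: loss {total_loss / len(train_loader.dataset):.2f}")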
To assess model-generated images, three quantitative metrics are commonly used. Mean Squared Error (MSE) was calculated by summing the squared pixel-wise differences between the generated image and a ground-truth image. Structural Similarity Index Measure (SSIM) is a metric that evaluates image quality by comparing two images based on structural information, luminance, and contrast [3]. SSIM can be used to compare images of any size, whereas MSE depends on pixel count. The SSIM score ranges from -1 to 1, where 1 indicates identical images. Fréchet Inception Distance (FID) is a metric for quantifying the realism and diversity of generated images. As FID is a distance measure, lower scores indicate a better reconstruction of a set of images.
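As an illustration of how these metrics might be computed (a sketch, not the exact evaluation code used here), the snippet below assumes scikit-image for SSIM and torchmetrics for FID; the image arrays/tensors named real_images_uint8 and generated_images_uint8 are placeholders.

import numpy as np
from skimage.metrics import structural_similarity as ssim
from torchmetrics.image.fid import FrechetInceptionDistance

def mse(generated, target):
    # Sum of squared pixel-wise differences between generated and ground-truth image
    return float(np.sum((generated - target) ** 2))

def ssim_score(generated, target):
    # SSIM in [-1, 1]; data_range is 1.0 for images scaled to [0, 1]
    return ssim(generated, target, data_range=1.0)

# FID compares Inception feature statistics of real vs. generated image sets
# (expects uint8 tensors of shape (N, 3, H, W)); the inputs here are placeholders
fid = FrechetInceptionDistance(feature=2048)
fid.update(real_images_uint8, real=True)
fid.update(generated_images_uint8, real=False)
print(fid.compute())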
Before scaling up to full text-to-image, the CVAEs are tested on image reconstruction and generation with Fashion-MNIST. Fashion-MNIST is an MNIST-like dataset consisting of a training set of 60,000 examples and a test set of 10,000 examples. Each example is a 28×28 grayscale image associated with a label from 10 classes [4].
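Loading the dataset is straightforward with torchvision; a minimal sketch is shown below (the batch size is an assumption), and it provides the train_loader assumed in the training-loop sketch above.

from torch.utils.data import DataLoader
from torchvision import datasets, transforms

transform = transforms.ToTensor()  # converts to [0, 1] tensors of shape (1, 28, 28)
train_set = datasets.FashionMNIST(root="data", train=True, download=True, transform=transform)
test_set = datasets.FashionMNIST(root="data", train=False, download=True, transform=transform)

train_loader = DataLoader(train_set, batch_size=128, shuffle=True)
test_loader = DataLoader(test_set, batch_size=128, shuffle=False)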
Preprocessing functions were created to extract the relevant keyword containing the class name from the input short text using regular-expression matching. Extra descriptors (synonyms) were added for many classes to account for similar clothing items included in each class (e.g. Coat & Jacket).
import re

classes = {
    'T-shirt': 0,
    'Top': 0,
    'Trouser': 1,
    'Pants': 1,
    'Pullover': 2,
    'Sweater': 2,
    'Hoodie': 2,
    'Dress': 3,
    'Coat': 4,
    'Jacket': 4,
    'Sandal': 5,
    'Shirt': 6,
    'Sneaker': 7,
    'Shoe': 7,
    'Bag': 8,
    'Ankle boot': 9,
    'Boot': 9
}

def word_to_text(input_str, classes, model, device):
    label = class_embedding(input_str, classes)
    if label == -1:
        return Exception("No valid label")
    samples = sample_images(model, num_samples=4, label=label, device=device)
    plot_samples(samples, input_str, torch.tensor([label]))
    return

def class_embedding(input_str, classes):
    for key in list(classes.keys()):
        template = rf'(?i)\b{key}\b'
        output = re.search(template, input_str)
        if output:
            return classes[key]
    return -1
The class name is then converted to its class number and used as the conditional input to the CVAE. To generate an image, the class label extracted from the short text description is passed into the decoder along with random samples from a Gaussian distribution as the latent-space variable.
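The sample_images helper referenced above is not shown in full here; the following is a minimal sketch of what it might look like, under the assumption that generation simply pairs a standard-Gaussian latent sample with the class label in the decoder, as described.

@torch.no_grad()
def sample_images(model, num_samples, label, device):
    # Draw latent vectors from a standard Gaussian and pair them with the class label
    model.eval()
    z = torch.randn(num_samples, model.latent_dim, device=device)
    c = torch.full((num_samples,), label, dtype=torch.long, device=device)
    return model.decode(z, c).cpu()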
Before testing generation, image reconstruction is examined to verify the functionality of the CVAE. Because the network is convolutional and works on 28×28 images, it could be trained in less than an hour with fewer than 100 epochs.
Reconstructions capture the general shape of the ground-truth images, but sharp, high-frequency features are missing. Any text or intricate design patterns are blurred in the model output. Inputting any short text containing a Fashion-MNIST class gives generated outputs resembling the reconstructed images.
The generated images have an MSE of 11 and an SSIM of 0.76. These represent good generations, showing that on simple, small images, CVAEs can generate quality outputs. GANs and DDPMs will produce higher-quality images with complex features, but CVAEs can handle simple cases.
When scaling up to image generation from text of any length, more robust methods are needed than regular-expression matching. To do this, OpenAI's CLIP is used to convert text into a high-dimensional embedding vector. The embedding model is used in its ViT-B/32 configuration, which outputs embeddings of size 512. A limitation of the CLIP model is that it has a maximum token length of 77, with research showing an even smaller effective length of 20 [5]. Thus, in cases where the input text contains multiple sentences, the text is split by sentence and passed through the CLIP encoder. The resulting embeddings are averaged together to create the final output embedding.
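As a standalone illustration of this splitting-and-averaging step, the sketch below uses OpenAI's clip package; the simple period-based sentence split is an assumption, and the class shown next wraps the same idea in its encode_condition method.

import clip
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
clip_model, _ = clip.load("ViT-B/32", device=device)

@torch.no_grad()
def embed_long_text(text):
    # Split on periods as a rough sentence boundary (assumption), embed each
    # sentence separately to stay under CLIP's 77-token limit, then average.
    sentences = [s.strip() for s in text.split(".") if s.strip()]
    tokens = clip.tokenize(sentences).to(device)          # (num_sentences, 77)
    embeddings = clip_model.encode_text(tokens).float()   # (num_sentences, 512)
    return embeddings.mean(dim=0)                         # (512,)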
A long-text model requires far more sophisticated training data than Fashion-MNIST, so the COCO dataset was used. The COCO dataset has annotations (which are not entirely robust, as discussed later) that can be passed into CLIP to get embeddings. However, COCO images are of size 640×480, meaning that even with cropping transforms, a larger network is required. Both the adding and concatenating conditional-input architectures were tested for long-text-to-image generation, but the concatenating approach is shown here:
import clip

class cVAE(nn.Module):
    def __init__(self, latent_dim=128):
        super().__init__()
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.clip_model, _ = clip.load("ViT-B/32", device=device)
        self.clip_model.eval()
        for param in self.clip_model.parameters():
            param.requires_grad = False

        self.latent_dim = latent_dim
        # Modified encoder for 128x128 input
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 32, 4, stride=2, padding=1),    # 64x64
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, stride=2, padding=1),   # 32x32
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 128, 4, stride=2, padding=1),  # 16x16
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 256, 4, stride=2, padding=1), # 8x8
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, 512, 4, stride=2, padding=1), # 4x4
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Flatten()
        )
        self.flatten_size = 512 * 4 * 4  # Flattened size from encoder
        # Process CLIP embeddings for encoder
        self.condition_processor_encoder = nn.Sequential(
            nn.Linear(512, 1024)
        )
        self.fc_mu = nn.Linear(self.flatten_size + 1024, latent_dim)
        self.fc_var = nn.Linear(self.flatten_size + 1024, latent_dim)
        self.decoder_input = nn.Linear(latent_dim + 512, 512 * 4 * 4)
        # Modified decoder for 128x128 output
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(512, 256, 4, stride=2, padding=1),  # 8x8
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1),  # 16x16
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),   # 32x32
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1),    # 64x64
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 16, 4, stride=2, padding=1),    # 128x128
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.Conv2d(16, 3, 3, stride=1, padding=1),              # 128x128
            nn.Sigmoid()
        )

    def encode_condition(self, text):
        with torch.no_grad():
            embeddings = []
            for sentence in text:
                embeddings.append(
                    self.clip_model.encode_text(clip.tokenize(sentence).to('cuda')).type(torch.float32)
                )
            return torch.mean(torch.stack(embeddings), dim=0)

    def encode(self, x, c):
        x = self.encoder(x)
        c = self.condition_processor_encoder(c)
        x = torch.cat([x, c], dim=1)
        return self.fc_mu(x), self.fc_var(x)

    def reparameterize(self, mu, log_var):
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z, c):
        z = torch.cat([z, c], dim=1)
        z = self.decoder_input(z)
        z = z.view(-1, 512, 4, 4)
        return self.decoder(z)

    def forward(self, x, c):
        mu, log_var = self.encode(x, c)
        z = self.reparameterize(mu, log_var)
        return self.decode(z, c), mu, log_var
Another main point of investigation was image generation and reconstruction on images of different sizes, specifically resizing COCO images to 64×64, 128×128, and 256×256. After training the network, reconstruction results are examined first.
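A sketch of how the COCO images and captions might be loaded and resized for these experiments is shown below, using torchvision's CocoCaptions; the file paths, batch size, and the 128×128 setting are placeholders, and the custom collate function simply keeps the variable-length caption lists intact.

from torch.utils.data import DataLoader
from torchvision import datasets, transforms

image_size = 128  # also run with 64 and 256
transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.CenterCrop(image_size),
    transforms.ToTensor(),   # scales to [0, 1] to match the Sigmoid/BCE setup
])

# Paths are placeholders for wherever COCO is stored locally
coco_train = datasets.CocoCaptions(
    root="coco/train2017",
    annFile="coco/annotations/captions_train2017.json",
    transform=transform,
)
train_loader = DataLoader(coco_train, batch_size=64, shuffle=True,
                          collate_fn=lambda batch: tuple(zip(*batch)))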
All image sizes lead to reconstructions that capture the background along with some feature outlines and correct colors. However, as image size increases, more features can be recovered. This makes sense: although it takes much longer to train a model on a larger image size, there is more information that can be captured and learned by the model.
With image generation, it is very difficult to produce high-quality images. Most images have backgrounds to some degree and blurred features. This is to be expected for image generation from a CVAE. It occurs for both concatenation and addition of the conditional input, but the concatenated approach performs better. This is likely because concatenated conditional inputs do not interfere with important features and ensure the conditioning information is preserved distinctly, so conditions can be ignored if they are irrelevant. Additive conditional inputs, on the other hand, can interfere with existing features and disrupt the network when updating weights during backpropagation.
All of the COCO generated images have a far lower SSIM of about 0.4 compared with the SSIM on Fashion-MNIST. MSE is proportional to image size, so it is difficult to quantify differences. FID for the COCO image generations is in the 200s, further evidence that the COCO CVAE generated images are not robust.
The biggest limitation in trying to use CVAEs for image generation is, well, the CVAE. The amount of information that can be contained and reconstructed/generated is heavily dependent on the size of the latent space. A latent space that is too small won't capture any meaningful information, and the required size scales with the size of the output image: a 28×28 image needs a much smaller latent space than a 64×64 image (since the pixel count grows quadratically with image size). However, a latent space bigger than the actual image adds unnecessary capacity and at that point would simply create a 1-to-1 mapping. For the COCO dataset, a latent space of at least 512 is needed to capture some features. And while CVAEs are generative models, a convolutional encoder and decoder is a relatively rudimentary network. The adversarial training style of a GAN or the sophisticated denoising process of a DDPM allows for much more complex image generation.
Another major limitation in image generation is the dataset the model is trained on. Although the COCO dataset has annotations, they are not extensively detailed. To train complex generative models, a different dataset should be used for training. COCO does not provide object locations or extra information about background details. A complex feature vector from the CLIP encoder cannot be efficiently utilized by a CVAE trained on COCO.
Although CVAEs and image generation on COCO have their limitations, they still produce a workable image generation model. Further code and details can be provided, just reach out!
[1] Kingma, Diederik P., et al. "Auto-Encoding Variational Bayes." arXiv:1312.6114 (2013).
[2] Sohn, Kihyuk, et al. "Learning Structured Output Representation using Deep Conditional Generative Models." NeurIPS Proceedings (2015).
[3] Nilsson, J., et al. "Understanding SSIM." arXiv:2102.12037 (2020).
[4] Xiao, Han, et al. "Fashion-MNIST: a Novel Image Dataset for Benchmarking Machine Learning Algorithms." arXiv:1708.07747 (2017) (MIT license).
[5] Zhang, B., et al. "Long-CLIP: Unlocking the Long-Text Capability of CLIP." arXiv:2403.15378 (2024).
A shout-out to my group project partners Jake Hession (Deloitte Consultant), Ashley Hong (Google SWE), and Julian Kuppel (Quant)!