GEGT Application
This section provides technical details about the GEGT (Generalized Encoder-Generator for Text) applet: the supported data format, the model architecture, the training and evaluation process, and deployment instructions for using the trained .bin file.
Data Processing Details
GEGT supports a specific format for image-to-text generation tasks. Ensure your data follows the structure below:
Folder Format
The dataset should be organized in a folder that includes:
- A subfolder named input containing image files named 1.jpg, 2.jpg, 3.jpg, etc.
- A text file named output.txt where each line corresponds to the prompt or expected output for the image with the same number.
Example structure:
/your_dataset_folder/
├── input/
│ ├── 1.jpg
│ ├── 2.jpg
│ └── 3.jpg
└── output.txt
Contents of output.txt:
Describe the object in image 1.
Describe the object in image 2.
Describe the object in image 3.
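If you want to load this layout with a PyTorch DataLoader, the following is a minimal sketch of a Dataset that pairs input/<n>.jpg with line n of output.txt. The class name, the processor/tokenizer arguments, and the max_length default are illustrative assumptions, not part of the GEGT applet itself.

import os
from PIL import Image
from torch.utils.data import Dataset


class ImageTextFolderDataset(Dataset):
    """Hypothetical helper: pairs input/<n>.jpg with line n of output.txt."""

    def __init__(self, root, processor, tokenizer, max_length=50):
        self.image_dir = os.path.join(root, "input")
        with open(os.path.join(root, "output.txt"), encoding="utf-8") as f:
            self.texts = [line.strip() for line in f]
        self.processor = processor
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        # Images are 1-indexed: 1.jpg corresponds to the first line of output.txt.
        image = Image.open(os.path.join(self.image_dir, f"{idx + 1}.jpg")).convert("RGB")
        pixel_values = self.processor(images=image, return_tensors="pt")["pixel_values"].squeeze(0)
        labels = self.tokenizer(
            self.texts[idx],
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )["input_ids"].squeeze(0)
        return pixel_values, labels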
Model Architecture
GEGT fine-tunes a combined vision-text Transformer built on a VisionEncoderDecoderModel, which pairs a ViT encoder with a BERT decoder.
import torch
import torch.nn as nn
from transformers import ViTConfig, BertConfig, VisionEncoderDecoderConfig, VisionEncoderDecoderModel


class FineTunedTransformerGEGT(nn.Module):
    def __init__(self, vocab_size):
        super(FineTunedTransformerGEGT, self).__init__()
        # Build a ViT encoder + BERT decoder pair wrapped in a VisionEncoderDecoderModel.
        config_encoder = ViTConfig()
        config_decoder = BertConfig()
        config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(config_encoder, config_decoder)
        self.model = VisionEncoderDecoderModel(config=config)
        # Project decoder hidden states onto the target vocabulary.
        self.lm_head = nn.Linear(self.model.config.decoder.hidden_size, vocab_size)

    def forward(self, pixel_values, input_ids=None, attention_mask=None, labels=None):
        # Encode the image, then let the text decoder cross-attend to the image features.
        encoder_outputs = self.model.encoder(pixel_values)
        decoder_outputs = self.model.decoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_outputs.last_hidden_state,
            output_hidden_states=True,
        )
        decoder_last_hidden_state = decoder_outputs.hidden_states[-1]
        logits = self.lm_head(decoder_last_hidden_state)

        if labels is not None:
            # Shift so that tokens < n predict token n (standard causal LM loss).
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
            return loss
        return logits
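For orientation, the vocab_size argument should match the decoder's tokenizer. A minimal instantiation sketch, assuming a BERT tokenizer (the model name here is only an example):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # example decoder tokenizer
model = FineTunedTransformerGEGT(vocab_size=tokenizer.vocab_size)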
Training Details
GEGT is trained with CrossEntropyLoss and the Adam optimizer. Training processes batches of images and their corresponding text labels.
def train(self, qthread=None, progress_updated=None, loss_updated=None):
    # Total number of optimizer steps across all epochs (used for the progress bar).
    total_steps = self.num_epochs * len(self.train_loader)
    current_step = 0
    for epoch in range(self.num_epochs):
        total_loss = 0
        for pixel_vals, labels in self.train_loader:
            self.optimizer.zero_grad()
            pixel_values = pixel_vals.to(self.device)
            labels = labels.to(self.device)

            # The model returns raw logits when no labels are passed.
            logits = self.model(pixel_values=pixel_values, input_ids=labels)

            # Shift so that tokens < n predict token n.
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()
            loss = self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
            total_loss += loss.item()

            loss.backward()
            self.optimizer.step()

            current_step += 1
            if progress_updated is not None:
                progress_updated.emit(int((current_step / total_steps) * 100))
            if qthread is not None:
                qthread.msleep(1)

        avg_loss = total_loss / len(self.train_loader)
        if loss_updated is not None:
            loss_updated.emit(avg_loss)
        print(f"Epoch {epoch+1}/{self.num_epochs}, Loss: {avg_loss:.4f}")
Evaluation Details
The GEGT model is evaluated with perplexity; the evaluation loop is shown below:
def evaluate(self, eval_dataset=None):
    self.model.eval()
    total_loss = 0
    total_tokens = 0
    val_loader = DataLoader(eval_dataset, batch_size=self.batch_size, shuffle=False) if eval_dataset else self.val_loader
    with torch.no_grad():
        for pixel_vals, labels in val_loader:
            pixel_values = pixel_vals.to(self.device)
            labels = labels.to(self.device)

            # Raw logits from the fine-tuned model.
            logits = self.model(pixel_values=pixel_values, input_ids=labels)

            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()
            loss = self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

            # Accumulate the summed (not mean) loss so the corpus-level average
            # is weighted by the number of tokens in each batch.
            total_loss += loss.item() * shift_labels.numel()
            total_tokens += shift_labels.numel()

    # Perplexity is the exponential of the average per-token cross-entropy.
    perplexity = torch.exp(torch.tensor(total_loss / total_tokens))
    return perplexity.item()
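For intuition, perplexity is simply the exponential of the average per-token cross-entropy, so lower values mean the model assigns higher probability to the reference text. A quick worked example:

import torch

avg_token_loss = 2.0  # average cross-entropy per token, in nats
perplexity = torch.exp(torch.tensor(avg_token_loss))
print(perplexity)  # tensor(7.3891) -> a perplexity of about 7.4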
Deployment Details
Use the following method to load the trained model and generate text from a new input image.
import torch
from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
from PIL import Image


class ImageToTextGenerator:
    def __init__(self, model_path, processor_name, tokenizer_name, max_length=50):
        """
        Initialize the vision-text model and required processors.

        Parameters:
        - model_path (str): Path to the saved .bin model weights.
        - processor_name (str): Name of the vision processor (e.g., ViTImageProcessor).
        - tokenizer_name (str): Name of the tokenizer for the decoder.
        - max_length (int): Maximum length of generated text.
        """
        self.processor = ViTImageProcessor.from_pretrained(processor_name)
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        # Build an encoder-decoder skeleton, then load the fine-tuned weights on top of it.
        self.model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(processor_name, tokenizer_name)
        self.model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
        self.model.eval()
        self.max_length = max_length

    def generate_text(self, image_path):
        """
        Generate text from an image file.

        Parameters:
        - image_path (str): Path to the image.

        Returns:
        - str: Generated text description.
        """
        image = Image.open(image_path).convert("RGB")
        inputs = self.processor(images=image, return_tensors="pt")
        # Seed the decoder with a single pad token and decode one forward pass greedily.
        input_ids = torch.tensor([[self.tokenizer.pad_token_id]])
        with torch.no_grad():
            outputs = self.model(pixel_values=inputs["pixel_values"], input_ids=input_ids)
        logits = outputs.logits
        predicted_ids = torch.argmax(logits, dim=-1)
        return self.tokenizer.decode(predicted_ids[0], skip_special_tokens=True)
# Example usage
model_path = "" # path to .bin file
processor_name = "google/vit-base-patch16-224-in21k" # example vision encoder
tokenizer_name = "bert-base-uncased" # example text decoder
image_path = "" # path to your image file
generator = ImageToTextGenerator(model_path, processor_name, tokenizer_name)
output = generator.generate_text(image_path)
print(output)
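Note that generate_text above runs a single forward pass and greedily picks one token per position, which is fast but not autoregressive. If your checkpoint is a plain VisionEncoderDecoderModel state dict, a possible alternative (a sketch, not part of the applet) is to add a method that uses the model's built-in generate loop; the decoder start token shown here assumes a BERT-style tokenizer:

def generate_text_autoregressive(self, image_path):
    # Sketch: autoregressive decoding via the built-in generate() loop.
    image = Image.open(image_path).convert("RGB")
    pixel_values = self.processor(images=image, return_tensors="pt")["pixel_values"]
    with torch.no_grad():
        generated_ids = self.model.generate(
            pixel_values,
            max_length=self.max_length,
            decoder_start_token_id=self.tokenizer.cls_token_id,  # assumption: BERT-style decoder
            pad_token_id=self.tokenizer.pad_token_id,
        )
    return self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)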