GEGT Application

This section provides technical details about the GEGT (Generalized Encoder-Generator for Text) application. It covers the supported data format, the model architecture, the training and evaluation procedures, and deployment instructions for the trained .bin weights file.

Data Processing Details

GEGT expects image-to-text training data in a specific folder layout. Make sure your dataset follows the structure below:

Folder Format

The dataset should be organized in a folder that includes:

  • A subfolder named input containing image files numbered sequentially: 1.jpg, 2.jpg, 3.jpg, etc.
  • A text file named output.txt in which line N holds the target text (the prompt or expected output) for image N.jpg: line 1 pairs with 1.jpg, line 2 with 2.jpg, and so on.

Example structure:

/your_dataset_folder/
├── input/
│   ├── 1.jpg
│   ├── 2.jpg
│   └── 3.jpg
└── output.txt

output.txt

Describe the object in image 1.
Describe the object in image 2.
Describe the object in image 3.

Model Architecture

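GEGT trains a combined vision-text Transformer built on Hugging Face's VisionEncoderDecoderModel, pairing a ViT encoder with a BERT decoder and adding its own linear output head. Note that the model below is constructed from default configurations (ViTConfig, BertConfig). Its weights therefore start randomly initialized rather than being loaded from pretrained checkpoints.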

class FineTunedTransformerGEGT(nn.Module):
    """ViT-encoder / BERT-decoder text generator with its own output head.

    The wrapped ``VisionEncoderDecoderModel`` is built from default
    ``ViTConfig`` / ``BertConfig`` instances, so its weights are freshly
    initialised rather than loaded from a pretrained checkpoint. A separate
    linear layer (``lm_head``) maps the decoder's final hidden state to
    scores over ``vocab_size`` tokens.

    The attribute names ``model`` and ``lm_head`` determine the state-dict
    keys, so renaming them would break previously saved checkpoints.
    """

    def __init__(self, vocab_size):
        super().__init__()
        combined_config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(
            ViTConfig(), BertConfig()
        )
        self.model = VisionEncoderDecoderModel(config=combined_config)
        decoder_width = self.model.config.decoder.hidden_size
        self.lm_head = nn.Linear(decoder_width, vocab_size)

    def forward(self, pixel_values, input_ids=None, attention_mask=None, labels=None):
        """Encode the images, decode the tokens and project to vocabulary logits.

        Parameters:
        - pixel_values: image batch fed to the ViT encoder.
        - input_ids: decoder input token ids.
        - attention_mask: decoder attention mask (optional).
        - labels: optional target token ids; when given, the next-token
          cross-entropy loss is returned instead of the logits.

        Returns:
        - Tensor: the scalar loss when ``labels`` is provided, otherwise the
          logits, a rank-3 tensor whose last dimension is ``vocab_size``.
        """
        image_features = self.model.encoder(pixel_values).last_hidden_state
        decoded = self.model.decoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            encoder_hidden_states=image_features,
            output_hidden_states=True,
        )
        logits = self.lm_head(decoded.hidden_states[-1])

        if labels is None:
            return logits

        # The prediction at position t is scored against the token at t + 1:
        # drop the final prediction and the first target so they line up.
        predictions = logits[:, :-1, :].contiguous()
        targets = labels[:, 1:].contiguous()
        return nn.functional.cross_entropy(
            predictions.reshape(-1, predictions.size(-1)),
            targets.reshape(-1),
        )

Training Details

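GEGT is trained with a cross-entropy loss and the Adam optimizer. Each training step takes a batch of images and their tokenized text labels and feeds the labels to the decoder as its inputs (teacher forcing). The prediction at each position is then scored against the next label token.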

def train(self, qthread=None, progress_updated=None, loss_updated=None):
    """Train ``self.model`` for ``self.num_epochs`` epochs on ``self.train_loader``.

    Each batch is a ``(pixel_values, labels)`` pair. The labels are fed to the
    model as ``input_ids`` (teacher forcing). ``self.loss_fn`` compares the
    logits at position ``t`` with the label at position ``t + 1``.

    Parameters:
    - qthread: optional object exposing ``msleep(ms)``; called once per epoch
      to yield briefly (presumably a Qt ``QThread``; confirm with caller).
    - progress_updated: optional signal-like object exposing ``emit(int)``;
      receives the percentage of epochs completed after each epoch.
    - loss_updated: optional signal-like object exposing ``emit(float)``;
      receives the epoch's mean per-batch loss.

    Any of the three may be ``None``. If the loader yields no batches, the
    reported epoch loss is ``nan``.
    """
    # evaluate() switches the model to eval mode; re-enable training-time
    # behaviour (e.g. dropout) before updating weights.
    self.model.train()

    total_steps = self.num_epochs

    for epoch in range(self.num_epochs):
        total_loss = 0.0
        num_batches = 0
        for pixel_vals, labels in self.train_loader:
            self.optimizer.zero_grad()
            pixel_values = pixel_vals.to(self.device)
            labels = labels.to(self.device)

            logits = self.model(pixel_values=pixel_values, input_ids=labels)
            # Align the prediction at position t with the target at t + 1.
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()
            loss = self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
            total_loss += loss.item()
            num_batches += 1

            loss.backward()
            self.optimizer.step()

        progress = int(((epoch + 1) / total_steps) * 100)
        if progress_updated is not None:
            progress_updated.emit(progress)
        if qthread is not None:
            qthread.msleep(1)

        avg_loss = total_loss / num_batches if num_batches else float("nan")
        if loss_updated is not None:
            loss_updated.emit(avg_loss)
        print(f"Epoch {epoch+1}/{self.num_epochs}, Loss: {avg_loss:.4f}")

Evaluation Details

The GEGT model is evaluated with perplexity: the exponential of the average cross-entropy loss per predicted token (lower is better). The evaluation loop is shown below:

def evaluate(self, eval_dataset=None):
    """Return the perplexity of ``self.model`` on a validation set.

    Perplexity is ``exp(total cross-entropy / number of scored tokens)``,
    where both totals run over every next-token prediction in the set.

    Parameters:
    - eval_dataset: optional dataset yielding ``(pixel_values, labels)``
      pairs. When given, it is wrapped in a non-shuffling ``DataLoader`` with
      ``self.batch_size``; otherwise ``self.val_loader`` is used.

    Returns:
    - float: the perplexity.

    Raises:
    - ZeroDivisionError: if the evaluation data contains no scored tokens.

    The model's train/eval mode is restored on exit.
    """
    if eval_dataset is not None:
        val_loader = DataLoader(eval_dataset, batch_size=self.batch_size, shuffle=False)
    else:
        val_loader = self.val_loader

    # self.loss_fn is expected to behave like nn.CrossEntropyLoss. With its
    # default 'mean' reduction each batch loss is already a per-token average.
    # It must be scaled back up by the number of tokens it averaged over
    # before summing. Otherwise dividing by total_tokens would divide twice.
    reduction = getattr(self.loss_fn, "reduction", "mean")
    ignore_index = getattr(self.loss_fn, "ignore_index", -100)

    was_training = self.model.training
    self.model.eval()
    total_loss = 0.0
    total_tokens = 0
    try:
        with torch.no_grad():
            for pixel_vals, labels in val_loader:
                pixel_values = pixel_vals.to(self.device)
                labels = labels.to(self.device)

                # Align the prediction at position t with the target at t + 1.
                shift_labels = labels[:, 1:].contiguous()
                # Count only the targets the loss actually scores.
                n_tokens = int((shift_labels != ignore_index).sum().item())
                if n_tokens == 0:
                    continue

                logits = self.model(pixel_values=pixel_values, input_ids=labels)
                shift_logits = logits[:, :-1, :].contiguous()
                loss = self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

                if reduction == "sum":
                    total_loss += loss.item()
                elif reduction == "none":
                    total_loss += loss.sum().item()
                else:
                    total_loss += loss.item() * n_tokens
                total_tokens += n_tokens
    finally:
        self.model.train(was_training)

    if total_tokens == 0:
        raise ZeroDivisionError("evaluation data contains no target tokens; perplexity is undefined")

    perplexity = torch.exp(torch.tensor(total_loss / total_tokens))
    return perplexity.item()

Deployment Details

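Use the following class to load the trained weights and generate text for a new input image. The .bin file must contain a state dict whose keys match the VisionEncoderDecoderModel built from the given processor and tokenizer names, because the weights are loaded with load_state_dict's default strict key matching.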

import torch
from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
from PIL import Image

class ImageToTextGenerator:
    """Load trained vision-text weights and generate text for single images."""

    def __init__(self, model_path, processor_name, tokenizer_name, max_length=50):
        """
        Initialize the vision-text model and required processors.

        Parameters:
        - model_path (str): Path to the saved .bin model weights.
        - processor_name (str): Name of the vision processor (e.g., ViTImageProcessor);
          also used as the pretrained encoder checkpoint.
        - tokenizer_name (str): Name of the tokenizer for the decoder; also used
          as the pretrained decoder checkpoint.
        - max_length (int): Maximum length, in tokens, of generated text.
        """
        self.processor = ViTImageProcessor.from_pretrained(processor_name)
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        self.model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(processor_name, tokenizer_name)

        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from trusted sources.
        # NOTE(review): load_state_dict is strict, so the .bin must be a
        # VisionEncoderDecoderModel state dict. A state dict saved directly from
        # a wrapper module (keys prefixed "model." plus an extra head) will not
        # load as-is; confirm how the training code saves its weights.
        state_dict = torch.load(model_path, map_location=torch.device("cpu"))
        self.model.load_state_dict(state_dict)
        self.model.eval()
        self.max_length = max_length

        # generate() must know which token starts decoding and which one ends
        # it. BERT-style tokenizers mark these with [CLS] / [SEP]; fall back to
        # BOS / EOS for tokenizers that define those instead.
        start_id = self.tokenizer.cls_token_id
        if start_id is None:
            start_id = self.tokenizer.bos_token_id
        end_id = self.tokenizer.sep_token_id
        if end_id is None:
            end_id = self.tokenizer.eos_token_id
        self._decoder_start_token_id = start_id
        self._eos_token_id = end_id

    def generate_text(self, image_path):
        """
        Generate text from an image file.

        Parameters:
        - image_path (str): Path to the image.

        Returns:
        - str: Generated text description, with special tokens removed.
        """
        # Close the file handle once the pixels have been decoded.
        with Image.open(image_path) as img:
            image = img.convert("RGB")
        pixel_values = self.processor(images=image, return_tensors="pt")["pixel_values"]

        # Autoregressive decoding up to max_length tokens. A single forward
        # pass would only predict one step, and VisionEncoderDecoderModel
        # expects decoder tokens as `decoder_input_ids`, not `input_ids`.
        with torch.no_grad():
            generated_ids = self.model.generate(
                pixel_values,
                max_length=self.max_length,
                decoder_start_token_id=self._decoder_start_token_id,
                pad_token_id=self.tokenizer.pad_token_id,
                eos_token_id=self._eos_token_id,
            )

        return self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)

# Example usage
# Fill in model_path and image_path before running. With the empty strings
# below, opening the weights file and the image will fail.
# processor_name doubles as the pretrained encoder checkpoint and
# tokenizer_name as the pretrained decoder checkpoint, so the .bin weights
# must match the architecture those names produce.
model_path = ""  # path to .bin file
processor_name = "google/vit-base-patch16-224-in21k"  # example vision encoder
tokenizer_name = "bert-base-uncased"  # example text decoder
image_path = ""  # path to your image file

generator = ImageToTextGenerator(model_path, processor_name, tokenizer_name)
output = generator.generate_text(image_path)
print(output)