import torch

from typing import Tuple
from dataclasses import dataclass
from transformers import PretrainedConfig, PreTrainedModel

from .csd import CSD
from .config import CSDConfig
|
| |
|
@dataclass
class CSDOutput:
    """Output container for :class:`CSDModel.forward`.

    Holds the three tensors returned by the underlying ``CSD`` backbone for a
    batch of images. Shapes are determined by the backbone's configuration
    (presumably ``(batch, embed_dim)`` each — TODO confirm against ``CSD``).
    """

    # Overall image embedding produced by the CSD backbone.
    image_embeds: torch.Tensor
    # Style-specific embedding (CSD = Contrastive Style Descriptors).
    style_embeds: torch.Tensor
    # Content-specific embedding.
    content_embeds: torch.Tensor
| |
|
| |
|
class CSDModel(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around the ``CSD`` backbone.

    Exposes the backbone through the standard ``transformers`` interface so it
    can be loaded/saved with ``from_pretrained``/``save_pretrained`` using
    :class:`CSDConfig`.
    """

    config_class = CSDConfig

    def __init__(self, config: CSDConfig) -> None:
        super().__init__(config)

        # Instantiate the CSD vision transformer from the ViT hyper-parameters
        # carried by the config.
        self.model = CSD(
            vit_input_resolution=config.vit_input_resolution,
            vit_patch_size=config.vit_patch_size,
            vit_width=config.vit_width,
            vit_layers=config.vit_layers,
            vit_heads=config.vit_heads,
            vit_output_dim=config.vit_output_dim,
        )

    @torch.inference_mode()
    def forward(self, pixel_values: torch.Tensor, **kwargs) -> CSDOutput:
        """Embed a batch of images and wrap the backbone's outputs.

        Args:
            pixel_values: Preprocessed image tensor passed straight to the
                ``CSD`` backbone (expected layout is whatever ``CSD`` accepts,
                presumably ``(batch, channels, H, W)`` — verify against CSD).
            **kwargs: Ignored; accepted for ``transformers`` pipeline
                compatibility.

        Returns:
            A :class:`CSDOutput` with image, style, and content embeddings.

        NOTE(review): ``torch.inference_mode`` disables autograd here, so this
        entry point cannot be used for fine-tuning — confirm that is intended.
        """
        outputs = self.model(pixel_values)
        image_embeds, style_embeds, content_embeds = outputs
        return CSDOutput(
            image_embeds=image_embeds,
            style_embeds=style_embeds,
            content_embeds=content_embeds,
        )
| |
|