# modeling_unimernet.py
  1. import os
  2. import warnings
  3. from typing import Optional
  4. import torch
  5. from ftfy import fix_text
  6. from loguru import logger
  7. from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, AutoTokenizer, PretrainedConfig, PreTrainedModel
  8. from transformers.models.vision_encoder_decoder.configuration_vision_encoder_decoder import VisionEncoderDecoderConfig
  9. from transformers.models.vision_encoder_decoder.modeling_vision_encoder_decoder import (
  10. VisionEncoderDecoderModel,
  11. logger as base_model_logger,
  12. )
  13. from .unimer_swin import UnimerSwinConfig, UnimerSwinModel, UnimerSwinImageProcessor
  14. from .unimer_mbart import UnimerMBartConfig, UnimerMBartForCausalLM
  15. from ...utils import latex_rm_whitespace
# Register the project's custom encoder/decoder config and model classes with
# the HuggingFace Auto* factories so AutoConfig / AutoModel /
# AutoModelForCausalLM can resolve them by their `model_type` string.
AutoConfig.register(UnimerSwinConfig.model_type, UnimerSwinConfig)
AutoConfig.register(UnimerMBartConfig.model_type, UnimerMBartConfig)
AutoModel.register(UnimerSwinConfig, UnimerSwinModel)
AutoModelForCausalLM.register(UnimerMBartConfig, UnimerMBartForCausalLM)
  20. # TODO: rewrite tokenizer
  21. class TokenizerWrapper:
  22. def __init__(self, tokenizer):
  23. self.tokenizer = tokenizer
  24. self.pad_token_id = self.tokenizer.pad_token_id
  25. self.bos_token_id = self.tokenizer.bos_token_id
  26. self.eos_token_id = self.tokenizer.eos_token_id
  27. def __len__(self):
  28. return len(self.tokenizer)
  29. def tokenize(self, text, **kwargs):
  30. return self.tokenizer(
  31. text,
  32. return_token_type_ids=False,
  33. return_tensors="pt",
  34. padding="longest",
  35. truncation=True,
  36. **kwargs,
  37. )
  38. def token2str(self, tokens) -> list:
  39. generated_text = self.tokenizer.batch_decode(tokens, skip_special_tokens=True)
  40. generated_text = [fix_text(text) for text in generated_text]
  41. return generated_text
  42. def detokenize(self, tokens):
  43. toks = [self.tokenizer.convert_ids_to_tokens(tok) for tok in tokens]
  44. for b in range(len(toks)):
  45. for i in reversed(range(len(toks[b]))):
  46. if toks[b][i] is None:
  47. toks[b][i] = ''
  48. toks[b][i] = toks[b][i].replace('Ġ', ' ').strip()
  49. if toks[b][i] in ([self.tokenizer.bos_token, self.tokenizer.eos_token, self.tokenizer.pad_token]):
  50. del toks[b][i]
  51. return toks
class UnimernetModel(VisionEncoderDecoderModel):
    """Vision encoder-decoder that pairs a UnimerSwin image encoder with a
    UnimerMBart causal-LM decoder to transcribe formula images into LaTeX.
    """

    def __init__(
        self,
        config: Optional[PretrainedConfig] = None,
        encoder: Optional[PreTrainedModel] = None,
        decoder: Optional[PreTrainedModel] = None,
    ):
        # VisionEncoderDecoderModel's checking log has bug, disable for temp.
        base_model_logger.disabled = True
        try:
            super().__init__(config, encoder, decoder)
        finally:
            # Re-enable the logger even if super().__init__ raises.
            base_model_logger.disabled = False
        # The tokenizer is loaded from the same path the config was loaded
        # from, so _name_or_path is mandatory here.
        if not config or not hasattr(config, "_name_or_path"):
            raise RuntimeError("config._name_or_path is required by UnimernetModel.")
        model_path = config._name_or_path
        self.transform = UnimerSwinImageProcessor()
        self.tokenizer = TokenizerWrapper(AutoTokenizer.from_pretrained(model_path))
        self._post_check()

    def _post_check(self):
        """Sync tokenizer limits with the decoder config and sanity-check ids."""
        tokenizer = self.tokenizer
        # Keep the tokenizer's max length consistent with the decoder's
        # positional capacity; warn when overriding the pretrained value.
        if tokenizer.tokenizer.model_max_length != self.config.decoder.max_position_embeddings:
            warnings.warn(
                f"decoder.max_position_embeddings={self.config.decoder.max_position_embeddings}," +
                f" but tokenizer.model_max_length={tokenizer.tokenizer.model_max_length}, will set" +
                f" tokenizer.model_max_length to {self.config.decoder.max_position_embeddings}.")
            tokenizer.tokenizer.model_max_length = self.config.decoder.max_position_embeddings
        # NOTE(review): `assert` is stripped under `python -O`, so these
        # consistency checks vanish in optimized runs — consider raising.
        assert self.config.decoder.vocab_size == len(tokenizer)
        assert self.config.decoder_start_token_id == tokenizer.bos_token_id
        assert self.config.pad_token_id == tokenizer.pad_token_id

    @classmethod
    def from_checkpoint(cls, model_path: str, model_filename: str = "pytorch_model.pth", state_dict_strip_prefix="model.model."):
        """Build a UnimernetModel from a local checkpoint directory.

        model_path: directory containing the HF config, tokenizer files, and
            the raw checkpoint file.
        model_filename: checkpoint filename inside model_path.
        state_dict_strip_prefix: training-time wrapper prefix to strip from
            state-dict keys (e.g. "model.model.").

        Raises RuntimeError when the state dict is empty or keys are missing.
        """
        config = VisionEncoderDecoderConfig.from_pretrained(model_path)
        config._name_or_path = model_path
        # Re-wrap the generic sub-configs as the project-specific classes.
        config.encoder = UnimerSwinConfig(**vars(config.encoder))
        config.decoder = UnimerMBartConfig(**vars(config.decoder))
        encoder = UnimerSwinModel(config.encoder)
        decoder = UnimerMBartForCausalLM(config.decoder)
        model = cls(config, encoder, decoder)
        # load model weights
        model_file_path = os.path.join(model_path, model_filename)
        checkpoint = torch.load(model_file_path, map_location="cpu", weights_only=True)
        # Some checkpoints nest weights under a "model" key.
        state_dict = checkpoint["model"] if "model" in checkpoint else checkpoint
        if not state_dict:
            raise RuntimeError("state_dict is empty.")
        if state_dict_strip_prefix:
            state_dict = {
                k[len(state_dict_strip_prefix):] if k.startswith(state_dict_strip_prefix) else k: v
                for k, v in state_dict.items()
            }
        missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
        # Extra keys are tolerated with a warning; missing keys are fatal.
        if len(unexpected_keys) > 0:
            warnings.warn("Unexpected key(s) in state_dict: {}.".format(", ".join(f'"{k}"' for k in unexpected_keys)))
        if len(missing_keys) > 0:
            raise RuntimeError("Missing key(s) in state_dict: {}.".format(", ".join(f'"{k}"' for k in missing_keys)))
        return model

    def forward_bak(self, samples):
        """Training-style forward pass returning the LM loss.

        NOTE(review): this calls `self.model(...)`, but no `model` attribute
        is assigned in this file and VisionEncoderDecoderModel is not known
        to define one — looks like backup/dead code (the "_bak" suffix
        suggests so); confirm before reviving.
        """
        pixel_values, text = samples["image"], samples["text_input"]
        text_inputs = self.tokenizer.tokenize(text).to(pixel_values.device)
        decoder_input_ids, decoder_attention_mask = text_inputs["input_ids"], text_inputs["attention_mask"]
        # Replicate single-channel (grayscale) input to the 3 channels the
        # encoder expects.
        num_channels = pixel_values.shape[1]
        if num_channels == 1:
            pixel_values = pixel_values.repeat(1, 3, 1, 1)
        # Copy the ids (`* 1`) and mask padding with -100 so the loss
        # ignores those positions.
        labels = decoder_input_ids * 1
        labels = labels.masked_fill(labels == self.tokenizer.pad_token_id, -100)
        # Teacher forcing: inputs drop the last token, labels drop the first.
        loss = self.model(
            pixel_values=pixel_values,
            decoder_input_ids=decoder_input_ids[:, :-1],
            decoder_attention_mask=decoder_attention_mask[:, :-1],
            labels=labels[:, 1:],
        ).loss
        return {"loss": loss}

    def generate(self, samples, do_sample: bool = False, temperature: float = 0.2, top_p: float = 0.95, batch_size=64):
        """Autoregressively transcribe a batch of images.

        samples: dict with key "image" holding a (batch, channels, H, W)
            pixel tensor — assumed from the indexing below; TODO confirm.
        Returns a dict with raw ids, token lists, decoded strings, and
        whitespace-normalized LaTeX strings.
        """
        pixel_values = samples["image"]
        num_channels = pixel_values.shape[1]
        if num_channels == 1:
            # Grayscale → 3 channels for the vision encoder.
            pixel_values = pixel_values.repeat(1, 3, 1, 1)
        kwargs = {}
        if do_sample:
            kwargs["temperature"] = temperature
            kwargs["top_p"] = top_p
        # Cap generation length by batch size — the "6g"/"8g" comments
        # presumably refer to GPU memory budgets (TODO confirm). NOTE(review):
        # this permanently mutates the shared tokenizer's model_max_length.
        if self.tokenizer.tokenizer.model_max_length > 1152:
            if batch_size <= 32:
                self.tokenizer.tokenizer.model_max_length = 1152  # 6g
            else:
                self.tokenizer.tokenizer.model_max_length = 1344  # 8g
        outputs = super().generate(
            pixel_values=pixel_values,
            max_new_tokens=self.tokenizer.tokenizer.model_max_length,  # required
            decoder_start_token_id=self.tokenizer.tokenizer.bos_token_id,
            do_sample=do_sample,
            **kwargs,
        )
        # Drop the decoder-start token before decoding.
        outputs = outputs[:, 1:].cpu().numpy()
        pred_tokens = self.tokenizer.detokenize(outputs)
        pred_str = self.tokenizer.token2str(outputs)
        fixed_str = [latex_rm_whitespace(s) for s in pred_str]
        return {"pred_ids": outputs, "pred_tokens": pred_tokens, "pred_str": pred_str, "fixed_str": fixed_str}