model.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452
  1. import math
  2. import re
  3. from typing import Iterable, List, Optional, Tuple
  4. import numpy as np
  5. import torch
  6. from sglang.srt.layers.quantization.base_config import QuantizationConfig
  7. from sglang.version import __version__ as sglang_version
  # `mm_utils` moved from `sglang.srt` to `sglang.srt.multimodal` in sglang 0.4.9.
  # Probe the module layout directly instead of comparing version strings: a
  # lexicographic comparison treats e.g. "0.4.10" as older than "0.4.9" and would
  # pick the old (no longer existing) module path.
  try:
      # sglang >= 0.4.9
      from sglang.srt.multimodal.mm_utils import (
          get_anyres_image_grid_shape,
      )
  except ImportError:
      # 0.4.7 <= sglang < 0.4.9
      from sglang.srt.mm_utils import (
          get_anyres_image_grid_shape,
      )
  18. from sglang.srt.model_executor.forward_batch_info import ForwardBatch
  19. from sglang.srt.model_loader.weight_utils import default_weight_loader
  20. from sglang.srt.models.qwen2 import Qwen2ForCausalLM
  21. from sglang.srt.utils import add_prefix
  22. from torch import nn
  23. from transformers import (
  24. CLIPVisionConfig,
  25. CLIPVisionModel,
  26. SiglipVisionConfig,
  27. SiglipVisionModel,
  28. )
  29. from ..vlm_hf_model.configuration_mineru2 import Mineru2QwenConfig
  30. from ..vlm_hf_model.modeling_mineru2 import build_vision_projector
  31. from ...utils.models_download_utils import auto_download_and_get_model_root_path
  def flatten_nested_list(nested_list):
      """Flatten arbitrarily nested lists into a single flat list, depth first.

      Only ``list`` instances are descended into; every other value (tuples,
      tensors, scalars, ...) is kept as a leaf. A non-list argument is returned
      wrapped in a one-element list.
      """
      if not isinstance(nested_list, list):
          return [nested_list]
      flat = []
      for element in nested_list:
          flat.extend(flatten_nested_list(element))
      return flat
  def downgrade_modality(modality):
      """Map a modality value to the lowercase name used by this model.

      The value is converted with ``str()`` and matched by substring, in priority
      order, so ``MULTI_IMAGES`` is recognised before the ``IMAGE`` substring it
      contains.

      Raises:
          ValueError: If none of the known modality names occurs in ``str(modality)``.
      """
      modality_str = str(modality)
      for marker, label in (
          ("MULTI_IMAGES", "multi-images"),
          ("IMAGE", "image"),
          ("VIDEO", "video"),
          ("AUDIO", "audio"),
      ):
          if marker in modality_str:
              return label
      raise ValueError(f"Unexpected modality: {modality_str}")
  class Mineru2QwenForCausalLM(nn.Module):
      """MinerU2 vision-language model for sglang: a vision tower and projector feeding Qwen2.

      Images are encoded by a CLIP or SigLIP vision tower, projected by
      ``multi_modal_mlp`` into the language model's embedding space, and written
      into the token embeddings at the placeholder positions reserved by
      :meth:`pad_input_ids`.
      """

      def __init__(
          self,
          config: Mineru2QwenConfig,
          quant_config: Optional[QuantizationConfig] = None,
          prefix: str = "",
      ) -> None:
          """Build the vision tower, projector and Qwen2 language model.

          Args:
              config: Model configuration. Missing ``projector_hidden_act`` and
                  ``image_token_index`` are filled in place with defaults, and
                  ``mm_vision_select_feature`` is forced to ``"full"`` for SigLIP.
              quant_config: Optional quantization config forwarded to the language model.
              prefix: Parameter-name prefix forwarded to the language model.

          Raises:
              ValueError: If the vision tower is neither CLIP nor SigLIP, or the
                  configured feature-selection strategy is unsupported.
          """
          super().__init__()
          self.config = config
          if getattr(self.config, "projector_hidden_act", None) is None:
              self.config.projector_hidden_act = "gelu"
          if getattr(self.config, "image_token_index", None) is None:
              self.config.image_token_index = 151646

          # Load the vision tower from the (auto-downloaded) model root.
          mm_vision_tower = self.config.mm_vision_tower
          model_root_path = auto_download_and_get_model_root_path(mm_vision_tower, "vlm")
          mm_vision_tower = f"{model_root_path}/{mm_vision_tower}"
          if "clip" in mm_vision_tower:
              vision_config = CLIPVisionConfig.from_pretrained(mm_vision_tower)
              self.vision_tower = CLIPVisionModel(vision_config)  # type: ignore
          elif "siglip" in mm_vision_tower:
              vision_config = SiglipVisionConfig.from_pretrained(mm_vision_tower)
              self.vision_tower = SiglipVisionModel(vision_config)  # type: ignore
              # Siglip needs all feature tokens
              self.config.mm_vision_select_feature = "full"
          else:
              raise ValueError(f"Unexpected mm_vision_tower: {mm_vision_tower}")

          # The projector is deliberately named `multi_modal_mlp`: names containing
          # `proj` are commonly used by attention layers, and quantization code that
          # matches on parameter names could otherwise pick it up by mistake.
          self.multi_modal_mlp = build_vision_projector(config)
          self.language_model = Qwen2ForCausalLM(
              config,
              quant_config=quant_config,
              prefix=add_prefix("language_model", prefix),
          )
          if "unpad" in getattr(config, "mm_patch_merge_type", ""):
              self.language_model.model.image_newline = nn.Parameter(torch.empty(config.hidden_size))

          # Keep the vision tower on the same device as the language model.
          language_model_device = next(self.language_model.parameters()).device
          self.vision_tower = self.vision_tower.to(language_model_device)
          self.vision_tower.eval()

          self.vision_feature_layer = self.config.mm_vision_select_layer
          self.vision_feature_select_strategy = self.config.mm_vision_select_feature
          self.image_size = self.vision_tower.config.image_size
          self.patch_size = self.vision_tower.config.patch_size
          self.mm_patch_merge_type = getattr(self.config, "mm_patch_merge_type", "flat")
          self.image_aspect_ratio = getattr(self.config, "image_aspect_ratio", "square")
          self.image_grid_pinpoints = getattr(self.config, "image_grid_pinpoints", None)

          # sglang >= 0.4.9.post3 stores pixels in `item.feature` and image sizes in
          # `item.model_specific_data`; older versions use `pixel_values` / `image_sizes`.
          # Compared numerically: a string comparison misorders e.g. "0.4.10" vs "0.4.9".
          self._mm_item_uses_feature = self._version_key(sglang_version) >= self._version_key("0.4.9.post3")

          # Number of visual tokens per image (one per patch, plus a class token for "cls_patch").
          self.image_feature_len = int((self.image_size // self.patch_size) ** 2)
          if self.vision_feature_select_strategy in ("patch", "full"):
              pass
          elif self.vision_feature_select_strategy == "cls_patch":
              self.image_feature_len += 1
          else:
              raise ValueError(f"Unexpected select feature: {self.vision_feature_select_strategy}")

      @staticmethod
      def _version_key(version: str) -> Tuple[Tuple[int, ...], int]:
          """Return a sortable ``(release, post)`` key for a version string like ``0.4.9.post3``.

          Release components compare numerically and trailing zero components are
          ignored (``0.5 == 0.5.0``). A missing ``.postN`` sorts before ``post0``.
          Pre-release/dev tags are not distinguished.
          """
          release_match = re.match(r"\d+(?:\.\d+)*", version)
          release = [int(part) for part in release_match.group(0).split(".")] if release_match else []
          while release and release[-1] == 0:
              release.pop()
          post_match = re.search(r"post(\d+)", version)
          post = int(post_match.group(1)) if post_match else -1
          return tuple(release), post

      @staticmethod
      def _anyres_max_num_patches(image_aspect_ratio: str) -> Optional[int]:
          """Return N from an ``...anyres_max_N`` aspect-ratio setting, or None if it has no such suffix."""
          matched = re.match(r".*anyres_max_(\d+)", image_aspect_ratio)
          return int(matched.group(1)) if matched else None

      def pad_input_ids(self, input_ids: List[int], image_inputs):
          """Expand each image token in ``input_ids`` into the pad tokens its features will occupy.

          For each image (in order), the first remaining ``config.image_token_index``
          is replaced by ``new_image_feature_len`` copies of that image's ``pad_value``.
          Side effects on ``image_inputs`` (presumably sglang's multimodal inputs
          object): sets ``image_offsets`` (start index of each expanded image) and
          ``image_pad_len`` (number of pad tokens per image).

          Returns:
              The expanded token id list.
          """
          image_sizes = flatten_nested_list([item.image_sizes for item in image_inputs.mm_items])
          pad_values = [item.pad_value for item in image_inputs.mm_items]

          # With more than 16 images, each image is 2x2-pooled with stride 2 and
          # therefore contributes fewer tokens.
          if len(image_sizes) > 16:
              base_feature_len = math.ceil(self.image_size / self.patch_size / 2) ** 2
          else:
              base_feature_len = self.image_feature_len  # multiimage
          height = width = self.num_patches_per_side

          offset_list = []
          image_inputs.image_pad_len = []
          for image_idx, image_s in enumerate(image_sizes):
              new_image_feature_len = base_feature_len
              if "anyres" in self.config.image_aspect_ratio:
                  num_patch_width, num_patch_height = get_anyres_image_grid_shape(
                      image_s,
                      self.image_grid_pinpoints,
                      self.vision_tower.config.image_size,
                  )
                  # `unpad_image_shape` is intentionally not applied, matching forward(),
                  # which likewise skips `unpad_image`.
                  new_h = num_patch_height * height
                  new_w = num_patch_width * width
                  max_num_patches = self._anyres_max_num_patches(self.config.image_aspect_ratio)
                  if max_num_patches is not None:
                      times = math.sqrt(new_h * new_w / (max_num_patches * self.image_feature_len))
                      if times > 1.1:
                          new_h = int(new_h // times)
                          new_w = int(new_w // times)
                  # Each of the new_h rows carries new_w patch tokens plus one newline token.
                  new_image_feature_len += new_h * (new_w + 1)

              try:
                  offset = input_ids.index(self.config.image_token_index)
              except ValueError:
                  # NOTE(review): with no image token left, position 0 is overwritten.
                  offset = 0
              # old_len + pad_len - 1, because the image token itself is removed.
              input_ids = input_ids[:offset] + [pad_values[image_idx]] * new_image_feature_len + input_ids[offset + 1 :]
              offset_list.append(offset)
              image_inputs.image_pad_len.append(new_image_feature_len)
          image_inputs.image_offsets = offset_list
          return input_ids

      def encode_images(self, pixel_values: torch.Tensor) -> torch.Tensor:
          """Run the vision tower on ``pixel_values`` and project the selected layer's features.

          Raises:
              ValueError: For an unsupported feature-selection strategy.
          """
          pixel_values = pixel_values.to(device=self.vision_tower.device, dtype=self.vision_tower.dtype)
          # NOTE: output_hidden_states=True keeps every layer's hidden states (not memory
          # efficient), because the configured `mm_vision_select_layer` is read from them.
          image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
          selected_image_feature = image_outputs.hidden_states[self.vision_feature_layer]
          if self.vision_feature_select_strategy in ("default", "patch"):
              # Drop the first token (the class token for CLIP-style towers).
              selected_image_feature = selected_image_feature[:, 1:]
          elif self.vision_feature_select_strategy not in ("full", "cls_patch"):
              # "full" and "cls_patch" keep every token, as counted in __init__.
              raise ValueError(f"Unexpected select feature strategy: {self.vision_feature_select_strategy}")
          return self.multi_modal_mlp(selected_image_feature)

      def _collect_vision_inputs(self, selected_inputs):
          """Gather pixel values, image sizes and modalities for the requests that need vision.

          ``pixel_values`` and the modality list are ordered item by item across
          ``selected_inputs``. ``image_sizes`` holds one flattened list per request.
          """
          if self._mm_item_uses_feature:
              # sglang >= 0.4.9.post3
              pixel_values = flatten_nested_list([[item.feature for item in im.mm_items] for im in selected_inputs])
              image_sizes = [
                  flatten_nested_list([item.model_specific_data["image_sizes"] for item in im.mm_items])
                  for im in selected_inputs
              ]
          else:
              # 0.4.7 <= sglang < 0.4.9.post3
              pixel_values = flatten_nested_list(
                  [[item.pixel_values for item in im.mm_items] for im in selected_inputs]
              )
              image_sizes = [flatten_nested_list([item.image_sizes for item in im.mm_items]) for im in selected_inputs]
          # Built from the same requests as pixel_values so indexes stay aligned.
          modalities = [downgrade_modality(item.modality) for im in selected_inputs for item in im.mm_items]
          return pixel_values, image_sizes, modalities

      def _encode_pixel_values(self, pixel_values):
          """Encode a list of per-item pixel arrays into per-item projected features."""
          if pixel_values[0].ndim == 4:
              # llava-hd style: each item is rank 4 (assumed num_patch, C, H, W). Encode all
              # patches in one batch, then split the result back per item.
              concat_images = torch.tensor(
                  np.concatenate(pixel_values, axis=0),
                  device=self.vision_tower.device,
              )
              image_features = self.encode_images(concat_images)
              split_sizes = [image.shape[0] for image in pixel_values]
              return torch.split(image_features, split_sizes, dim=0)
          # One image per item: stack the items into a single batch.
          stacked = torch.tensor(np.array(pixel_values), device=self.vision_tower.device)
          return self.encode_images(stacked)

      def _merge_spatial_features(self, image_features, image_sizes, modalities):
          """Rearrange per-item features for ``spatial*`` merge types (anyres grids, newlines, video pooling).

          Raises:
              ValueError: If an item's modality is not image, multi-images or video.
          """
          new_image_features = []
          height = width = self.num_patches_per_side
          for image_idx, image_feature in enumerate(image_features):
              modality = modalities[image_idx]
              if modality == "image":
                  image_aspect_ratio = self.config.image_aspect_ratio  # single image
              elif modality in ("multi-images", "video"):
                  image_aspect_ratio = "pad"  # multi image
              else:
                  raise ValueError(f"Unexpected modality for image features: {modality}")

              if image_feature.shape[0] > 1 and "anyres" in image_aspect_ratio and modality == "image":
                  # First entry is the base (global) view; the rest are grid patches.
                  base_image_feature = image_feature[0]
                  image_feature = image_feature[1:]
                  assert height * width == base_image_feature.shape[0]
                  max_num_patches = self._anyres_max_num_patches(image_aspect_ratio)
                  if image_aspect_ratio == "anyres" or "anyres_max" in image_aspect_ratio:
                      try:
                          num_patch_width, num_patch_height = get_anyres_image_grid_shape(
                              image_sizes[image_idx][0],
                              self.config.image_grid_pinpoints,
                              self.image_size,
                          )
                      except Exception as e:
                          # Best effort: fall back to a 2x2 grid if the grid shape cannot be computed.
                          print(f"Error: {e}")
                          num_patch_width, num_patch_height = 2, 2
                      image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
                  else:
                      image_feature = image_feature.view(2, 2, height, width, -1)

                  if "unpad" in self.mm_patch_merge_type:
                      unit = image_feature.shape[2]
                      image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
                      image_feature = image_feature.flatten(1, 2).flatten(2, 3)
                      # `unpad_image` is intentionally not applied (pad_input_ids skips
                      # `unpad_image_shape` to match).
                      if max_num_patches is not None:
                          _, h, w = image_feature.shape
                          times = math.sqrt(h * w / (max_num_patches * unit**2))
                          if times > 1.1:
                              image_feature = nn.functional.interpolate(
                                  image_feature[None],
                                  [int(h // times), int(w // times)],
                                  mode="bilinear",
                              )[0]
                      # Append one image_newline column at the end of every row.
                      image_feature = torch.cat(
                          (
                              image_feature,
                              self.language_model.model.image_newline[:, None, None].expand(
                                  *image_feature.shape[:-1], 1
                              ),
                          ),
                          dim=-1,
                      )
                      image_feature = image_feature.flatten(1, 2).transpose(0, 1)
                  else:
                      image_feature = image_feature.permute(0, 2, 1, 3, 4).contiguous()
                      image_feature = image_feature.flatten(0, 3)
                  image_feature = torch.cat((base_image_feature, image_feature), dim=0)
                  image_feature = image_feature.unsqueeze(0)
              else:
                  if modality == "video":
                      # 2x2 pooling of every frame's patch grid.
                      num_of_frames = image_feature.shape[0]
                      image_feature = image_feature.view(num_of_frames, height, width, -1)
                      image_feature = image_feature.permute(0, 3, 1, 2).contiguous()  # N, C, H, W
                      frame_h, frame_w = image_feature.shape[2:]
                      scaled_shape = [math.ceil(frame_h / 2), math.ceil(frame_w / 2)]
                      image_feature = nn.functional.interpolate(image_feature, size=scaled_shape, mode="bilinear")
                      image_feature = image_feature.flatten(2).transpose(1, 2).contiguous()  # N, H*W, C
                  if "unpad" in self.mm_patch_merge_type:
                      image_feature = torch.cat(
                          (
                              image_feature,
                              # Expand to (bs, 1, hidden_dim) and concat at the end of the image tokens
                              self.language_model.model.image_newline[None, None].expand(
                                  image_feature.shape[0],
                                  1,
                                  image_feature.shape[-1],
                              ),
                          ),
                          dim=1,
                      )
              new_image_features.append(image_feature)
          return new_image_features

      def _scatter_image_features(self, input_embeds, image_features, image_inputs, need_vision, forward_batch):
          """Write image features into ``input_embeds`` (in place) at each image's pad positions.

          Only the part of each image that lies inside the current extend window
          (after the cached prefix and within the extend length) is written.
          """
          extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
          extend_seq_lens = forward_batch.extend_seq_lens.cpu().numpy()
          prefix_lens_cpu = forward_batch.extend_prefix_lens_cpu
          pt = 0  # index into image_features, advanced once per request that needs vision
          for i in range(forward_batch.batch_size):
              if not need_vision[i]:
                  continue
              start_idx = extend_start_loc_cpu[i]
              seq_len = extend_seq_lens[i]
              prefix_len = prefix_lens_cpu[i]
              # Multiple images
              for image_idx, image_offset in enumerate(image_inputs[i].image_offsets):
                  if image_offset + image_inputs[i].image_pad_len[image_idx] <= prefix_len:
                      continue  # image lies entirely inside the cached prefix
                  if image_offset >= prefix_len + seq_len:
                      break  # image starts after this extend window
                  tmp_image_feature = image_features[pt][image_idx]
                  pad_len = tmp_image_feature.shape[0]
                  input_offset = image_offset - prefix_len
                  left_idx = start_idx + input_offset
                  right_idx = left_idx + pad_len
                  assert right_idx > start_idx
                  if input_offset < 0:
                      # Image starts inside the cached prefix: drop its leading part.
                      left_idx = start_idx
                      tmp_image_feature = tmp_image_feature[-input_offset:]
                  if right_idx > start_idx + seq_len:
                      # Image runs past this extend window: drop its trailing part.
                      tmp_image_feature = tmp_image_feature[: start_idx + seq_len - right_idx]
                      right_idx = start_idx + seq_len
                  try:
                      input_embeds[left_idx:right_idx] = tmp_image_feature
                  except RuntimeError as e:
                      # Best effort: report the mismatch and keep the text embeddings.
                      print(f"RuntimeError in image encoding: {e}")
                      print(f"{input_embeds.shape=}, {tmp_image_feature.shape=}")
                      print(f"{start_idx=}, {image_offset=}, {prefix_len=}, {pad_len=}")
              pt += 1

      @torch.no_grad()
      def forward(
          self,
          input_ids: torch.LongTensor,
          positions: torch.Tensor,
          forward_batch: ForwardBatch,
      ) -> torch.Tensor:
          """Run the language model, splicing vision embeddings into extend (prefill) batches.

          Decode batches go straight to the language model. In extend mode, every
          request whose images are not fully covered by its cached prefix has its
          image pad positions overwritten with projected vision features.

          Raises:
              ValueError: For a forward mode that is neither extend nor decode, or an
                  unsupported image modality.
          """
          forward_mode = forward_batch.forward_mode
          if not forward_mode.is_extend():
              if forward_mode.is_decode():
                  return self.language_model(input_ids, positions, forward_batch)
              raise ValueError(f"Unexpected forward mode: {forward_batch.forward_mode}")

          bs = forward_batch.batch_size
          image_inputs = forward_batch.mm_inputs
          if image_inputs is None:
              # One "no images" entry per request keeps need_vision shaped (bs,).
              image_inputs = [None] * bs

          # Clamp input ids. The image-token positions hold image hash values (used for
          # prefix matching in the radix attention) that may fall outside the vocabulary;
          # their embeddings are replaced by vision embeddings anyway.
          input_ids.clamp_(min=0, max=self.config.vocab_size - 1)
          input_embeds = self.language_model.model.embed_tokens(input_ids)

          # End position (exclusive) of the last image of each request, or -1 if it has none.
          max_image_offset = []
          for im in image_inputs:
              if im and im.image_offsets:
                  max_image_offset.append(np.max(np.array(im.image_offsets) + np.array(im.image_pad_len)))
              else:
                  max_image_offset.append(-1)
          start_positions = positions[forward_batch.extend_start_loc].cpu().numpy()
          need_vision = start_positions <= np.array(max_image_offset)

          if need_vision.any():
              selected_inputs = [image_inputs[i] for i in range(bs) if need_vision[i]]
              pixel_values, image_sizes, modalities = self._collect_vision_inputs(selected_inputs)
              image_features = self._encode_pixel_values(pixel_values)
              if self.mm_patch_merge_type.startswith("spatial"):
                  image_features = self._merge_spatial_features(image_features, image_sizes, modalities)
              self._scatter_image_features(input_embeds, image_features, image_inputs, need_vision, forward_batch)

          return self.language_model(input_ids, positions, forward_batch, input_embeds=input_embeds)

      def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
          """Load checkpoint weights, remapping vision-side names onto this module's parameters.

          Names mentioning the projector, vision tower or image newline are renamed
          through ``projector_weights`` and loaded into the matching parameter
          (``KeyError`` if none matches). All other weights are delegated to the
          language model's ``load_weights``.
          """
          projector_weights = {
              "model.mm_projector": "multi_modal_mlp",
              # Update the vision tower weights if we find them in the checkpoint (it may be finetuned).
              "model.vision_tower.vision_tower": "vision_tower",
              "model.image_newline": "language_model.model.image_newline",
          }
          params_dict = dict(self.named_parameters())
          for name, loaded_weight in weights:
              if "projector" in name or "vision_tower" in name or "image_newline" in name:
                  for weight_name, param_name in projector_weights.items():
                      if weight_name in name:
                          name = name.replace(weight_name, param_name)
                  param = params_dict[name]
                  weight_loader = getattr(param, "weight_loader", default_weight_loader)
                  weight_loader(param, loaded_weight)
              else:
                  self.language_model.load_weights([(name, loaded_weight)])

      @property
      def num_patches_per_side(self):
          """Number of vision patches along each side of a (square) input image."""
          return self.image_size // self.patch_size
  # Model classes exported by this module (presumably discovered by sglang's
  # model loader through the `EntryClass` name — confirm against the registry).
  EntryClass = [
      Mineru2QwenForCausalLM,
  ]