_ppocrvl.py 36 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846
  1. # Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # This file is based on https://github.com/Kwai-Keye/Keye/blob/main/keye-vl-8b-preview/modeling_keye.py
  15. # Original header:
  16. # Copyright 2025 The Keye Team and The HuggingFace Inc. team. All rights reserved.
  17. #
  18. # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
  19. # and OPT implementations in this library. It has been modified from its
  20. # original forms to accommodate minor architectural differences compared
  21. # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
  22. #
  23. # Licensed under the Apache License, Version 2.0 (the "License");
  24. # you may not use this file except in compliance with the License.
  25. # You may obtain a copy of the License at
  26. #
  27. # http://www.apache.org/licenses/LICENSE-2.0
  28. #
  29. # Unless required by applicable law or agreed to in writing, software
  30. # distributed under the License is distributed on an "AS IS" BASIS,
  31. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  32. # See the License for the specific language governing permissions and
  33. # limitations under the License.
  34. from contextvars import ContextVar
  35. from dataclasses import dataclass
  36. from typing import List, Optional, Tuple, Union
  37. import numpy as np
  38. import paddle
  39. import paddle.nn as nn
  40. from ....common.vlm.generation import GenerationMixin
  41. from ....common.vlm.transformers.model_outputs import (
  42. CausalLMOutputWithCrossAttentions,
  43. ModelOutput,
  44. )
  45. from ._config import PPOCRVLConfig
  46. from ._ernie import Ernie4_5Model, Ernie4_5PretrainedModel
  47. from ._projector import Projector
  48. from ._siglip import SiglipVisionModel
@dataclass
class PPOCRVLCausalLMOutputWithPast(ModelOutput):
    """Dict-style output of `PPOCRVLForConditionalGeneration.forward`.

    `forward` builds this object when `return_dict` is true. When it is false,
    `forward` returns a plain tuple instead.
    """

    # Shifted next-token cross-entropy loss; only set when `labels` are given.
    loss: Optional[paddle.Tensor] = None
    # Output of `lm_head` over the decoder hidden states.
    logits: paddle.Tensor = None
    # KV cache returned by the language model, for use in the next decoding step.
    past_key_values: Optional[List[paddle.Tensor]] = None
    # Per-layer hidden states, forwarded from the language model output.
    hidden_states: Optional[Tuple[paddle.Tensor]] = None
    # Per-layer attention weights, forwarded from the language model output.
    attentions: Optional[Tuple[paddle.Tensor]] = None
    # Multimodal RoPE position offsets (see `get_rope_index`).
    rope_deltas: Optional[paddle.Tensor] = None
class PPOCRVLForConditionalGeneration(Ernie4_5PretrainedModel, GenerationMixin):
    """PP-OCR VL vision-language model for conditional text generation.

    The model is made of four parts:
    - a SigLIP vision encoder (`visual`);
    - a projector into the language model's embedding space (`mlp_AR`);
    - an ERNIE 4.5 decoder (`model`);
    - a bias-free linear LM head (`lm_head`).

    Image features are scattered into the token embeddings at the image-token
    positions before decoding.
    """

    # NOTE(review): the attributes below are presumably hooks read by the
    # pretrained-model / generation base classes (weight tying, config class,
    # modules that must not be split across devices, base-model prefix) —
    # confirm against Ernie4_5PretrainedModel.
    _tied_weights_keys = ["lm_head.weight"]
    config_class = PPOCRVLConfig
    _no_split_modules = ["Ernie4_5DecoderLayer", "SiglipEncoderLayer"]
    base_model_prefix = ""
  62. def __init__(self, config):
  63. super().__init__(config)
  64. self.mlp_AR = Projector(config, config.vision_config)
  65. self.visual = SiglipVisionModel(config.vision_config)
  66. self.model = Ernie4_5Model(config)
  67. self.vocab_size = config.vocab_size
  68. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias_attr=False)
  69. self.rope_deltas_var = ContextVar("rope_deltas", default=None)
  70. def get_input_embeddings(self):
  71. return self.model.embed_tokens
  72. def set_input_embeddings(self, value):
  73. self.model.embed_tokens = value
  74. def get_output_embeddings(self):
  75. return self.lm_head
  76. def set_output_embeddings(self, new_embeddings):
  77. self.lm_head = new_embeddings
  78. def set_decoder(self, decoder):
  79. self.model = decoder
  80. def get_decoder(self):
  81. return self.model
  82. def get_rope_index(
  83. self,
  84. input_ids: Optional[paddle.Tensor] = None,
  85. image_grid_thw: Optional[paddle.Tensor] = None,
  86. video_grid_thw: Optional[paddle.Tensor] = None,
  87. second_per_grid_ts: Optional[paddle.Tensor] = None,
  88. attention_mask: Optional[paddle.Tensor] = None,
  89. ) -> Tuple[paddle.Tensor, paddle.Tensor]:
  90. """
  91. Calculate the 3D rope index based on image and video's temporal, height and width in LLM.
  92. Explanation:
  93. Each embedding sequence contains vision embedding and text embedding or just contains text embedding.
  94. For pure text embedding sequence, the rotary position embedding has no difference with modern LLMs.
  95. Examples:
  96. input_ids: [T T T T T], here T is for text.
  97. temporal position_ids: [0, 1, 2, 3, 4]
  98. height position_ids: [0, 1, 2, 3, 4]
  99. width position_ids: [0, 1, 2, 3, 4]
  100. For vision and text embedding sequence, we calculate 3D rotary position embedding for vision part
  101. and 1D rotary position embedding for text part.
  102. Examples:
  103. Temporal (Time): 3 patches, representing different segments of the video in time.
  104. Height: 2 patches, dividing each frame vertically.
  105. Width: 2 patches, dividing each frame horizontally.
  106. We also have some important parameters:
  107. fps (Frames Per Second): The video's frame rate, set to 1. This means one frame is processed each second.
  108. tokens_per_second: This is a crucial parameter. It dictates how many "time-steps" or "temporal tokens" are conceptually packed into a one-second interval of the video. In this case, we have 25 tokens per second. So each second of the video will be represented with 25 separate time points. It essentially defines the temporal granularity.
  109. temporal_patch_size: The number of frames that compose one temporal patch. Here, it's 2 frames.
  110. interval: The step size for the temporal position IDs, calculated as tokens_per_second * temporal_patch_size / fps. In this case, 25 * 2 / 1 = 50. This means that each temporal patch will be have a difference of 50 in the temporal position IDs.
  111. input_ids: [V V V V V V V V V V V V T T T T T], here V is for vision.
  112. vision temporal position_ids: [0, 0, 0, 0, 50, 50, 50, 50, 100, 100, 100, 100]
  113. vision height position_ids: [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1]
  114. vision width position_ids: [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1]
  115. text temporal position_ids: [101, 102, 103, 104, 105]
  116. text height position_ids: [101, 102, 103, 104, 105]
  117. text width position_ids: [101, 102, 103, 104, 105]
  118. Here we calculate the text start position_ids as the max vision position_ids plus 1.
  119. Args:
  120. input_ids (`paddle.Tensor` of shape `(batch_size, sequence_length)`):
  121. Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
  122. it.
  123. image_grid_thw (`paddle.Tensor` of shape `(num_images, 3)`, *optional*):
  124. The temporal, height and width of feature shape of each image in LLM.
  125. video_grid_thw (`paddle.Tensor` of shape `(num_videos, 3)`, *optional*):
  126. The temporal, height and width of feature shape of each video in LLM.
  127. second_per_grid_ts (`paddle.Tensor` of shape `(num_videos)`, *optional*):
  128. The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs.
  129. attention_mask (`paddle.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  130. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  131. - 1 for tokens that are **not masked**,
  132. - 0 for tokens that are **masked**.
  133. Returns:
  134. position_ids (`paddle.Tensor` of shape `(3, batch_size, sequence_length)`)
  135. mrope_position_deltas (`paddle.Tensor` of shape `(batch_size)`)
  136. """
  137. spatial_merge_size = self.config.vision_config.spatial_merge_size
  138. image_token_id = self.config.image_token_id
  139. video_token_id = self.config.video_token_id
  140. vision_start_token_id = self.config.vision_start_token_id
  141. mrope_position_deltas = []
  142. if input_ids is not None and (
  143. image_grid_thw is not None or video_grid_thw is not None
  144. ):
  145. total_input_ids = input_ids
  146. if attention_mask is None:
  147. attention_mask = paddle.ones_like(total_input_ids)
  148. position_ids = paddle.ones(
  149. [3, input_ids.shape[0], input_ids.shape[1]],
  150. dtype=input_ids.dtype,
  151. )
  152. image_index, video_index = 0, 0
  153. for i, input_ids in enumerate(total_input_ids):
  154. input_ids = input_ids[attention_mask[i] == 1]
  155. image_nums, video_nums = 0, 0
  156. vision_start_indices = paddle.nonzero(
  157. input_ids == vision_start_token_id
  158. ).squeeze(1)
  159. vision_tokens = input_ids[vision_start_indices + 1]
  160. image_nums = (vision_tokens == image_token_id).sum()
  161. video_nums = (vision_tokens == video_token_id).sum()
  162. input_tokens = input_ids.tolist()
  163. llm_pos_ids_list: list = []
  164. st = 0
  165. remain_images, remain_videos = image_nums, video_nums
  166. for _ in range(image_nums + video_nums):
  167. if image_token_id in input_tokens and remain_images > 0:
  168. ed_image = input_tokens.index(image_token_id, st)
  169. else:
  170. ed_image = len(input_tokens) + 1
  171. if video_token_id in input_tokens and remain_videos > 0:
  172. ed_video = input_tokens.index(video_token_id, st)
  173. else:
  174. ed_video = len(input_tokens) + 1
  175. if ed_image < ed_video:
  176. t, h, w = (
  177. image_grid_thw[image_index][0],
  178. image_grid_thw[image_index][1],
  179. image_grid_thw[image_index][2],
  180. )
  181. second_per_grid_t = 0
  182. image_index += 1
  183. remain_images -= 1
  184. ed = ed_image
  185. else:
  186. t, h, w = (
  187. video_grid_thw[video_index][0],
  188. video_grid_thw[video_index][1],
  189. video_grid_thw[video_index][2],
  190. )
  191. if second_per_grid_ts is not None:
  192. second_per_grid_t = second_per_grid_ts[video_index]
  193. else:
  194. second_per_grid_t = 1.0
  195. video_index += 1
  196. remain_videos -= 1
  197. ed = ed_video
  198. llm_grid_t, llm_grid_h, llm_grid_w = (
  199. t.item(),
  200. h.item() // spatial_merge_size,
  201. w.item() // spatial_merge_size,
  202. )
  203. text_len = ed - st
  204. st_idx = (
  205. llm_pos_ids_list[-1].max() + 1
  206. if len(llm_pos_ids_list) > 0
  207. else 0
  208. )
  209. llm_pos_ids_list.append(
  210. paddle.arange(text_len).reshape((1, -1)).expand((3, -1))
  211. + st_idx
  212. )
  213. if paddle.is_tensor(second_per_grid_t):
  214. second_per_grid_t = second_per_grid_t.detach().item()
  215. range_tensor = paddle.arange(llm_grid_t).reshape((-1, 1))
  216. expanded_range = range_tensor.expand((-1, llm_grid_h * llm_grid_w))
  217. time_tensor = (
  218. expanded_range
  219. * second_per_grid_t
  220. * self.config.vision_config.tokens_per_second
  221. )
  222. time_tensor_long = time_tensor.astype("int64")
  223. t_index = time_tensor_long.flatten()
  224. h_index = (
  225. paddle.arange(llm_grid_h)
  226. .reshape((1, -1, 1))
  227. .expand((llm_grid_t, -1, llm_grid_w))
  228. .flatten()
  229. )
  230. w_index = (
  231. paddle.arange(llm_grid_w)
  232. .reshape((1, 1, -1))
  233. .expand((llm_grid_t, llm_grid_h, -1))
  234. .flatten()
  235. )
  236. llm_pos_ids_list.append(
  237. paddle.stack([t_index, h_index, w_index]) + text_len + st_idx
  238. )
  239. st = ed + llm_grid_t * llm_grid_h * llm_grid_w
  240. if st < len(input_tokens):
  241. st_idx = (
  242. llm_pos_ids_list[-1].max() + 1
  243. if len(llm_pos_ids_list) > 0
  244. else 0
  245. )
  246. text_len = len(input_tokens) - st
  247. llm_pos_ids_list.append(
  248. paddle.arange(text_len).reshape((1, -1)).expand((3, -1))
  249. + st_idx
  250. )
  251. llm_positions = paddle.concat(llm_pos_ids_list, axis=1).reshape((3, -1))
  252. position_ids[..., i, attention_mask[i] == 1] = llm_positions
  253. mrope_position_deltas.append(
  254. llm_positions.max() + 1 - len(total_input_ids[i])
  255. )
  256. mrope_position_deltas = paddle.to_tensor(mrope_position_deltas).unsqueeze(1)
  257. return position_ids, mrope_position_deltas
  258. else:
  259. if attention_mask is not None:
  260. position_ids = attention_mask.long().cumsum(-1) - 1
  261. position_ids.masked_fill_(attention_mask == 0, 1)
  262. position_ids = position_ids.unsqueeze(0).expand((3, -1, -1))
  263. max_position_ids = position_ids.max(0, keepdim=False)[0].max(
  264. -1, keepdim=True
  265. )[0]
  266. mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
  267. else:
  268. position_ids = (
  269. paddle.arange(input_ids.shape[1])
  270. .reshape((1, 1, -1))
  271. .expand((3, input_ids.shape[0], -1))
  272. )
  273. mrope_position_deltas = paddle.zeros(
  274. [input_ids.shape[0], 1],
  275. dtype=input_ids.dtype,
  276. )
  277. return position_ids, mrope_position_deltas
  278. def prepare_attention_mask_for_generation(
  279. self, input_ids, pad_token_id, eos_token_id
  280. ):
  281. """Avoid using attention_mask with flash_attn on generation."""
  282. if self.config.use_flash_attention:
  283. return None
  284. return super().prepare_attention_mask_for_generation(
  285. input_ids, pad_token_id, eos_token_id
  286. )
  287. def prepare_inputs_for_generation(
  288. self,
  289. input_ids,
  290. use_cache=False,
  291. past_key_values=None,
  292. inputs_embeds=None,
  293. pixel_values=None,
  294. pixel_values_videos=None,
  295. position_ids=None,
  296. **kwargs,
  297. ):
  298. if past_key_values:
  299. input_ids = input_ids[:, -1:]
  300. pixel_values = None
  301. pixel_values_videos = None
  302. # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
  303. if inputs_embeds is not None and past_key_values is None:
  304. model_inputs = {"inputs_embeds": inputs_embeds}
  305. else:
  306. model_inputs = {"input_ids": input_ids}
  307. model_inputs.update(
  308. {
  309. "past_key_values": past_key_values,
  310. "use_cache": use_cache,
  311. "pixel_values": pixel_values,
  312. "pixel_values_videos": pixel_values_videos,
  313. "position_ids": None,
  314. **kwargs,
  315. }
  316. )
  317. return model_inputs
  318. def update_model_kwargs_for_generation(
  319. self, outputs, model_kwargs, is_encoder_decoder=False
  320. ):
  321. """
  322. Updates model kwargs for generation.
  323. Args:
  324. outputs (Any): Model outputs.
  325. model_kwargs (dict): Current model kwargs.
  326. is_encoder_decoder (bool): Whether using encoder-decoder architecture.
  327. Returns:
  328. dict: Updated model kwargs.
  329. """
  330. # update cache
  331. if (
  332. isinstance(outputs, tuple)
  333. and len(outputs) > 1
  334. and not isinstance(outputs[1], paddle.Tensor)
  335. ):
  336. model_kwargs["past_key_values"] = outputs[1]
  337. if (
  338. isinstance(outputs, CausalLMOutputWithCrossAttentions)
  339. and "past_key_values" in outputs
  340. ):
  341. model_kwargs["past_key_values"] = outputs.past_key_values
  342. if (
  343. not is_encoder_decoder
  344. and model_kwargs.get("attention_mask", None) is not None
  345. ):
  346. # update attention mask
  347. attention_mask = model_kwargs["attention_mask"]
  348. model_kwargs["attention_mask"] = paddle.concat(
  349. [
  350. attention_mask,
  351. paddle.ones(
  352. [attention_mask.shape[0], 1], dtype=attention_mask.dtype
  353. ),
  354. ],
  355. axis=-1,
  356. )
  357. return model_kwargs
  358. def get_transpose_weight_keys(self):
  359. t_layers = [
  360. "out_proj",
  361. "q_proj",
  362. "k_proj",
  363. "v_proj",
  364. "lm_head",
  365. "gate_proj",
  366. "up_proj",
  367. "down_proj",
  368. "o_proj",
  369. "lm_head",
  370. "linear_1",
  371. "linear_2",
  372. "fc",
  373. "in_proj",
  374. ]
  375. keys = []
  376. for key, _ in self.get_hf_state_dict().items():
  377. for t_layer in t_layers:
  378. if t_layer in key and key.endswith("weight"):
  379. keys.append(key)
  380. return keys
  381. def get_hf_state_dict(self, *args, **kwargs):
  382. def _merge_attention_weights(
  383. q_weight=None,
  384. k_weight=None,
  385. v_weight=None,
  386. q_bias=None,
  387. k_bias=None,
  388. v_bias=None,
  389. ):
  390. if q_weight is not None and k_weight is not None and v_weight is not None:
  391. return paddle.concat([q_weight, k_weight, v_weight], axis=1)
  392. elif q_bias is not None and k_bias is not None and v_bias is not None:
  393. return paddle.concat([q_bias, k_bias, v_bias], axis=0)
  394. else:
  395. raise ValueError
  396. def _convert_to_hf_state_dict(current_state_dict):
  397. hf_state_dict = {}
  398. for key in list(current_state_dict.keys()):
  399. if "up_gate_proj" in key:
  400. combined_weights = current_state_dict[key]
  401. split_size = combined_weights.shape[-1] // 2
  402. gate_proj = combined_weights[..., :split_size]
  403. up_proj = combined_weights[..., split_size:]
  404. hf_state_dict[key.replace("up_gate_proj", "gate_proj")] = gate_proj
  405. hf_state_dict[key.replace("up_gate_proj", "up_proj")] = up_proj
  406. continue
  407. if "qkv_proj" in key and ("weight" in key or "bias" in key):
  408. combined_weights = current_state_dict[key]
  409. if getattr(self.config, "head_dim", None) is None:
  410. head_dim = self.hidden_size // self.num_heads
  411. else:
  412. head_dim = self.config.head_dim
  413. num_heads = self.config.num_attention_heads
  414. num_kv_heads = self.config.num_key_value_heads
  415. q_proj, k_proj, v_proj = paddle.split(
  416. combined_weights,
  417. [
  418. num_heads * head_dim,
  419. num_kv_heads * head_dim,
  420. num_kv_heads * head_dim,
  421. ],
  422. axis=-1,
  423. )
  424. if "weight" in key:
  425. hf_state_dict[
  426. key.replace("qkv_proj.weight", "q_proj.weight")
  427. ] = q_proj
  428. hf_state_dict[
  429. key.replace("qkv_proj.weight", "k_proj.weight")
  430. ] = k_proj
  431. hf_state_dict[
  432. key.replace("qkv_proj.weight", "v_proj.weight")
  433. ] = v_proj
  434. else: # bias
  435. hf_state_dict[key.replace("qkv_proj.bias", "q_proj.bias")] = (
  436. q_proj
  437. )
  438. hf_state_dict[key.replace("qkv_proj.bias", "k_proj.bias")] = (
  439. k_proj
  440. )
  441. hf_state_dict[key.replace("qkv_proj.bias", "v_proj.bias")] = (
  442. v_proj
  443. )
  444. continue
  445. if "up_gate_proj" not in key and "qkv_proj" not in key:
  446. hf_state_dict[key] = current_state_dict[key]
  447. new_hf_state_dict = {}
  448. keys_to_remove = set()
  449. for key, value in hf_state_dict.items():
  450. if "head.attention" in key and "out_proj" not in key:
  451. if "weight" in key:
  452. q_key = key
  453. k_key = key.replace("q_proj", "k_proj")
  454. v_key = key.replace("q_proj", "v_proj")
  455. if (
  456. q_key in hf_state_dict
  457. and k_key in hf_state_dict
  458. and v_key in hf_state_dict
  459. ):
  460. merged_weights = _merge_attention_weights(
  461. q_weight=hf_state_dict[q_key],
  462. k_weight=hf_state_dict[k_key],
  463. v_weight=hf_state_dict[v_key],
  464. )
  465. new_key = key.replace("q_proj.weight", "in_proj_weight")
  466. new_hf_state_dict[new_key] = merged_weights
  467. keys_to_remove.update([q_key, k_key, v_key])
  468. elif "bias" in key:
  469. q_key = key
  470. k_key = key.replace("q_proj", "k_proj")
  471. v_key = key.replace("q_proj", "v_proj")
  472. if (
  473. q_key in hf_state_dict
  474. and k_key in hf_state_dict
  475. and v_key in hf_state_dict
  476. ):
  477. merged_bias = _merge_attention_weights(
  478. q_bias=hf_state_dict[q_key],
  479. k_bias=hf_state_dict[k_key],
  480. v_bias=hf_state_dict[v_key],
  481. )
  482. new_key = key.replace("q_proj.bias", "in_proj_bias")
  483. new_hf_state_dict[new_key] = merged_bias
  484. keys_to_remove.update([q_key, k_key, v_key])
  485. else:
  486. new_hf_state_dict[key] = value
  487. for key in keys_to_remove:
  488. if key in new_hf_state_dict:
  489. del new_hf_state_dict[key]
  490. return new_hf_state_dict
  491. current_state_dict = self.state_dict(*args, **kwargs)
  492. hf_state_dict = _convert_to_hf_state_dict(current_state_dict)
  493. return hf_state_dict
  494. def set_hf_state_dict(self, state_dict, *args, **kwargs):
  495. def _split_attention_weights(weight=None, bias=None):
  496. if weight is not None:
  497. split_size = weight.shape[1] // 3
  498. q_weight = weight[:, :split_size]
  499. k_weight = weight[:, split_size : 2 * split_size]
  500. v_weight = weight[:, 2 * split_size :]
  501. return q_weight, k_weight, v_weight
  502. elif bias is not None:
  503. split_size = bias.shape[0] // 3
  504. q_bias = bias[:split_size]
  505. k_bias = bias[split_size : 2 * split_size]
  506. v_bias = bias[2 * split_size :]
  507. return q_bias, k_bias, v_bias
  508. def _convert_state_dict(old_state_dict):
  509. new_state_dict = {}
  510. for key, value in old_state_dict.items():
  511. if "head.attention.in_proj" in key:
  512. if key.endswith("weight"):
  513. q_w, k_w, v_w = _split_attention_weights(weight=value)
  514. new_state_dict[
  515. key.replace("in_proj_weight", "q_proj.weight")
  516. ] = q_w
  517. new_state_dict[
  518. key.replace("in_proj_weight", "k_proj.weight")
  519. ] = k_w
  520. new_state_dict[
  521. key.replace("in_proj_weight", "v_proj.weight")
  522. ] = v_w
  523. elif key.endswith("bias"):
  524. q_b, k_b, v_b = _split_attention_weights(bias=value)
  525. new_state_dict[key.replace("in_proj_bias", "q_proj.bias")] = q_b
  526. new_state_dict[key.replace("in_proj_bias", "k_proj.bias")] = k_b
  527. new_state_dict[key.replace("in_proj_bias", "v_proj.bias")] = v_b
  528. else:
  529. raise ValueError(f"Unexpected key: {key}")
  530. else:
  531. new_state_dict[key] = value
  532. for key in list(new_state_dict.keys()):
  533. if key.startswith("model."):
  534. if "mlp.gate_proj." in key:
  535. gate_proj = new_state_dict.pop(key)
  536. up_proj = new_state_dict.pop(
  537. key.replace("gate_proj", "up_proj")
  538. )
  539. new_state_dict[key.replace("gate_proj", "up_gate_proj")] = (
  540. paddle.concat([gate_proj, up_proj], axis=-1)
  541. )
  542. if "self_attn.q_proj" in key:
  543. q_proj = new_state_dict.pop(key)
  544. k_proj = new_state_dict.pop(key.replace("q_proj", "k_proj"))
  545. v_proj = new_state_dict.pop(key.replace("q_proj", "v_proj"))
  546. new_state_dict[key.replace("q_proj", "qkv_proj")] = (
  547. paddle.concat([q_proj, k_proj, v_proj], axis=-1)
  548. )
  549. return new_state_dict
  550. state_dict = _convert_state_dict(state_dict)
  551. std_state_dict = self.state_dict()
  552. assert std_state_dict.keys() == state_dict.keys()
  553. for key in std_state_dict:
  554. v1 = std_state_dict[key]
  555. state_dict[key] = state_dict[key].to(v1.place)
  556. return self.set_state_dict(state_dict, *args, **kwargs)
  557. def forward(
  558. self,
  559. input_ids: paddle.Tensor = None,
  560. attention_mask: Optional[paddle.Tensor] = None,
  561. position_ids: Optional[paddle.Tensor] = None,
  562. past_key_values: Optional[List[paddle.Tensor]] = None,
  563. inputs_embeds: Optional[paddle.Tensor] = None,
  564. labels: Optional[paddle.Tensor] = None,
  565. use_cache: Optional[bool] = None,
  566. output_attentions: Optional[bool] = None,
  567. output_hidden_states: Optional[bool] = None,
  568. return_dict: Optional[bool] = None,
  569. pixel_values: Optional[paddle.Tensor] = None,
  570. pixel_values_videos: Optional[paddle.Tensor] = None,
  571. image_grid_thw: Optional[paddle.Tensor] = None,
  572. video_grid_thw: Optional[paddle.Tensor] = None,
  573. rope_deltas: Optional[paddle.Tensor] = None,
  574. second_per_grid_ts: Optional[paddle.Tensor] = None,
  575. **kwargs,
  576. ) -> Union[Tuple, PPOCRVLCausalLMOutputWithPast]:
  577. output_attentions = (
  578. output_attentions
  579. if output_attentions is not None
  580. else self.config.output_attentions
  581. )
  582. output_hidden_states = (
  583. output_hidden_states
  584. if output_hidden_states is not None
  585. else self.config.output_hidden_states
  586. )
  587. return_dict = (
  588. return_dict if return_dict is not None else self.config.use_return_dict
  589. )
  590. curr_rope_deltas = self.rope_deltas_var.get()
  591. if inputs_embeds is None:
  592. if input_ids.shape[0] != 1:
  593. raise NotImplementedError
  594. inputs_embeds = self.model.embed_tokens(input_ids)
  595. if pixel_values is not None:
  596. pixel_values = pixel_values.astype(inputs_embeds.dtype)
  597. pixel_values = pixel_values.unsqueeze(0)
  598. siglip_position_ids = list()
  599. image_grid_hws = list()
  600. sample_indices = list()
  601. cu_seqlens = [0]
  602. pro = 0
  603. for idx, thw in enumerate(image_grid_thw):
  604. thw_tuple = tuple(thw.detach().cpu().numpy().tolist())
  605. numel = np.prod(thw_tuple)
  606. image_grid_hws.append(thw_tuple)
  607. image_position_ids = paddle.arange(numel) % np.prod(thw_tuple[1:])
  608. siglip_position_ids.append(image_position_ids)
  609. sample_indices.append(
  610. paddle.full((numel,), idx, dtype=paddle.int64)
  611. )
  612. cu_seqlens.append(cu_seqlens[-1] + numel)
  613. siglip_position_ids = paddle.concat(siglip_position_ids, axis=0)
  614. cu_seqlens = paddle.to_tensor(cu_seqlens, dtype=paddle.int32)
  615. sample_indices = paddle.concat(sample_indices, axis=0)
  616. vision_outputs = self.visual(
  617. pixel_values=pixel_values,
  618. image_grid_thw=image_grid_hws,
  619. position_ids=siglip_position_ids,
  620. vision_return_embed_list=True,
  621. interpolate_pos_encoding=True,
  622. sample_indices=sample_indices,
  623. cu_seqlens=cu_seqlens,
  624. return_pooler_output=False,
  625. use_rope=True,
  626. window_size=-1,
  627. )
  628. image_embeds = vision_outputs.last_hidden_state
  629. image_embeds = self.mlp_AR(image_embeds, image_grid_thw)
  630. n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
  631. image_embeds = paddle.concat(image_embeds, axis=0)
  632. n_image_features = image_embeds.shape[0]
  633. if n_image_tokens != n_image_features:
  634. raise ValueError(
  635. f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
  636. )
  637. mask = input_ids == self.config.image_token_id
  638. mask_unsqueezed = mask.unsqueeze(-1)
  639. mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
  640. image_mask = mask_expanded
  641. image_embeds = image_embeds.astype(inputs_embeds.dtype)
  642. inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
  643. else:
  644. if inputs_embeds.shape[0] != 1:
  645. raise NotImplementedError
  646. if attention_mask is not None and attention_mask.dtype != paddle.bool:
  647. attention_mask = paddle.cast(attention_mask, paddle.bool)
  648. # position_ids = None
  649. # if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme
  650. if position_ids is None and (
  651. attention_mask is None or attention_mask.ndim == 2
  652. ):
  653. # calculate RoPE index once per generation in the pre-fill stage only
  654. if curr_rope_deltas is None or (
  655. past_key_values is None or past_key_values[0] is None
  656. ):
  657. position_ids, rope_deltas = self.get_rope_index(
  658. input_ids,
  659. image_grid_thw,
  660. video_grid_thw,
  661. second_per_grid_ts,
  662. attention_mask,
  663. )
  664. self.rope_deltas_var.set(rope_deltas)
  665. # then use the prev pre-calculated rope-deltas to get the correct position ids
  666. else:
  667. batch_size, seq_length, _ = inputs_embeds.shape
  668. delta = (
  669. (past_key_values[0][0].shape[1] + curr_rope_deltas)
  670. if past_key_values is not None and past_key_values[0] is not None
  671. else 0
  672. )
  673. position_ids = paddle.arange(seq_length)
  674. position_ids = position_ids.reshape((1, -1)).expand((batch_size, -1))
  675. if (
  676. past_key_values is not None and past_key_values[0] is not None
  677. ): # otherwise `deltas` is an int `0`
  678. delta = delta.repeat_interleave(
  679. batch_size // delta.shape[0], axis=0
  680. )
  681. position_ids = position_ids.add(delta)
  682. position_ids = position_ids.unsqueeze(0).expand((3, -1, -1))
  683. outputs = self.model(
  684. input_ids=None,
  685. position_ids=position_ids,
  686. attention_mask=attention_mask,
  687. past_key_values=past_key_values,
  688. inputs_embeds=inputs_embeds,
  689. use_cache=use_cache,
  690. output_attentions=output_attentions,
  691. output_hidden_states=output_hidden_states,
  692. return_dict=return_dict,
  693. **kwargs,
  694. )
  695. hidden_states = outputs[0]
  696. logits = self.lm_head(hidden_states)
  697. loss = None
  698. if labels is not None:
  699. # Upcast to float if we need to compute the loss to avoid potential precision issues
  700. logits = logits.astype("float32")
  701. # Shift so that tokens < n predict n
  702. shift_logits = logits[..., :-1, :].contiguous()
  703. shift_labels = labels[..., 1:].contiguous()
  704. # Flatten the tokens
  705. loss_fct = paddle.nn.CrossEntropyLoss()
  706. shift_logits = shift_logits.reshape((-1, self.config.vocab_size))
  707. shift_labels = shift_labels.reshape((-1,))
  708. loss = loss_fct(shift_logits, shift_labels)
  709. if not return_dict:
  710. output = (logits,) + outputs[1:]
  711. return (loss,) + output if loss is not None else output
  712. return PPOCRVLCausalLMOutputWithPast(
  713. loss=loss,
  714. logits=logits,
  715. past_key_values=outputs.past_key_values,
  716. hidden_states=outputs.hidden_states,
  717. attentions=outputs.attentions,
  718. rope_deltas=curr_rope_deltas,
  719. )
  720. def generate(self, inputs, **kwargs):
  721. gen_kwargs = {
  722. "max_new_tokens": kwargs.get("max_new_tokens", 8192),
  723. "use_cache": kwargs.get("use_cache", True),
  724. }
  725. gen_kwargs = {**inputs, **gen_kwargs}
  726. with paddle.no_grad():
  727. generated_ids = super().generate(**gen_kwargs)
  728. return generated_ids
  729. def _get_image_nums_and_video_nums(
  730. self,
  731. input_ids: Optional[paddle.Tensor],
  732. ) -> Tuple[paddle.Tensor, paddle.Tensor]:
  733. """
  734. Get the number of images and videos for each sample to calculate the separation length of the sample tensor.
  735. These parameters are not passed through the processor to avoid unpredictable impacts from interface modifications.
  736. Args:
  737. input_ids (`paddle.Tensor` of shape `(batch_size, sequence_length)`):
  738. Indices of input sequence tokens in the vocabulary.
  739. Returns:
  740. image_nums (`paddle.Tensor` of shape `(batch_size, num_images_sample)`)
  741. video_nums (`paddle.Tensor` of shape `(batch_size, num_videos_sample)`)
  742. """
  743. image_token_id = self.config.image_token_id
  744. video_token_id = self.config.video_token_id
  745. vision_start_token_id = self.config.vision_start_token_id
  746. vision_start_mask = input_ids == vision_start_token_id
  747. vision_first_mask = paddle.roll(vision_start_mask, shifts=1, axis=1)
  748. image_mask = input_ids == image_token_id
  749. video_mask = input_ids == video_token_id
  750. image_nums = paddle.sum(vision_first_mask & image_mask, axis=1)
  751. video_nums = paddle.sum(vision_first_mask & video_mask, axis=1)
  752. return image_nums, video_nums