qwen2_vl.py 103 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
7227822792280228122822283228422852286228722882289229022912292229322942295229622972298229923002301230223032304230523062307230823092310231123122313231423152316231723182319232023212322232323242325232623272328232923302331233223332334233523362337233823392340234123422343234423452346234723482349235023512352235323542355235623572358235923602361236223632364236523662367236823692370237123722373237423752376237723782379238023812382238323842385238623872388238923902391239223932394239523962397239823992400240124022403240424052406240724082409241024112412241324142415241624172418241924202421242224232424242524262427242824292430243124322433243424352436243724382439244024412442244324442445244624472448244924502451245224532454245524562457245824592460246124622463246424652466246724682469247024712472247324742475247624772478247924802481248224832484248524862487248824892490249124922493249424952496249724982499250025012502250325042505250625072508250925102511251225132514251525162517251825192520252125222523252425252526252725282529253025312532253325342535253625372538253925402541254225432544254525462547254825492550255125522553255425552556255725582559256025612562256325642565256625672568256925702571257225732574257525762577257825792580258125822583258425852586258725882589259025912592
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import math
  15. import os
  16. from dataclasses import dataclass
  17. from functools import partial
  18. from typing import Any, Dict, List, Optional, Tuple, Union
  19. import paddle
  20. import paddle.distributed.fleet.meta_parallel as mpu
  21. import paddle.nn as nn
  22. import paddle.nn.functional as F
  23. from paddle import Tensor
  24. from paddle.distributed import fleet
  25. from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker
  26. from paddle.distributed.fleet.utils import recompute
  27. from .....utils import logging
  28. from ...common.vlm.activations import ACT2FN
  29. from ...common.vlm.bert_padding import index_first_axis, pad_input, unpad_input
  30. from ...common.vlm.flash_attn_utils import has_flash_attn_func
  31. from ...common.vlm.transformers import PretrainedConfig, PretrainedModel
  32. from ...common.vlm.transformers.model_outputs import (
  33. BaseModelOutputWithPast,
  34. ModelOutput,
  35. )
  36. flash_attn_func, flash_attn_varlen_func = has_flash_attn_func()
  37. _IS_NPU = "npu" in paddle.get_device()
  38. Linear = nn.Linear
  39. ColumnParallelLinear = mpu.ColumnParallelLinear
  40. RowParallelLinear = mpu.RowParallelLinear
  41. class Qwen2VLVisionConfig(PretrainedConfig):
  42. model_type = "qwen2_vl"
  43. def __init__(
  44. self,
  45. depth=32,
  46. embed_dim=1280,
  47. hidden_size=3584,
  48. hidden_act="quick_gelu",
  49. mlp_ratio=4,
  50. num_heads=16,
  51. in_channels=3,
  52. patch_size=14,
  53. spatial_merge_size=2,
  54. temporal_patch_size=2,
  55. attn_implementation="eager", # new added
  56. **kwargs,
  57. ):
  58. super().__init__(**kwargs)
  59. self.depth = depth
  60. self.embed_dim = embed_dim
  61. self.hidden_size = hidden_size
  62. self.hidden_act = hidden_act
  63. self.mlp_ratio = mlp_ratio
  64. self.num_heads = num_heads
  65. self.in_channels = in_channels
  66. self.patch_size = patch_size
  67. self.spatial_merge_size = spatial_merge_size
  68. self.temporal_patch_size = temporal_patch_size
  69. self.attn_implementation = attn_implementation
  70. @classmethod
  71. def from_pretrained(
  72. cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
  73. ) -> "PretrainedConfig":
  74. config_dict, kwargs = cls.get_config_dict(
  75. pretrained_model_name_or_path, **kwargs
  76. )
  77. if config_dict.get("model_type") == "qwen2_vl":
  78. config_dict = config_dict["vision_config"]
  79. if (
  80. "model_type" in config_dict
  81. and hasattr(cls, "model_type")
  82. and config_dict["model_type"] != cls.model_type
  83. ):
  84. logging.warning(
  85. f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
  86. f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
  87. )
  88. return cls.from_dict(config_dict, **kwargs)
  89. class Qwen2VLConfig(PretrainedConfig):
  90. r"""
  91. This is the configuration class to store the configuration of a [`Qwen2VLModel`]. It is used to instantiate a
  92. Qwen2-VL model according to the specified arguments, defining the model architecture. Instantiating a configuration
  93. with the defaults will yield a similar configuration to that of
  94. Qwen2-VL-7B-Instruct [Qwen/Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct).
  95. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
  96. documentation from [`PretrainedConfig`] for more information.
  97. Args:
  98. vocab_size (`int`, *optional*, defaults to 152064):
  99. Vocabulary size of the Qwen2VL model. Defines the number of different tokens that can be represented by the
  100. `inputs_ids` passed when calling [`Qwen2VLModel`]
  101. hidden_size (`int`, *optional*, defaults to 8192):
  102. Dimension of the hidden representations.
  103. intermediate_size (`int`, *optional*, defaults to 29568):
  104. Dimension of the MLP representations.
  105. num_hidden_layers (`int`, *optional*, defaults to 80):
  106. Number of hidden layers in the Transformer encoder.
  107. num_attention_heads (`int`, *optional*, defaults to 64):
  108. Number of attention heads for each attention layer in the Transformer encoder.
  109. num_key_value_heads (`int`, *optional*, defaults to 8):
  110. This is the number of key_value heads that should be used to implement Grouped Query Attention. If
  111. `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
  112. `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
  113. converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
  114. by meanpooling all the original heads within that group. For more details checkout [this
  115. paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`.
  116. hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
  117. The non-linear activation function (function or string) in the decoder.
  118. max_position_embeddings (`int`, *optional*, defaults to 32768):
  119. The maximum sequence length that this model might ever be used with.
  120. initializer_range (`float`, *optional*, defaults to 0.02):
  121. The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
  122. rms_norm_eps (`float`, *optional*, defaults to 1e-05):
  123. The epsilon used by the rms normalization layers.
  124. use_cache (`bool`, *optional*, defaults to `True`):
  125. Whether or not the model should return the last key/values attentions (not used by all models). Only
  126. relevant if `config.is_decoder=True`.
  127. tie_word_embeddings (`bool`, *optional*, defaults to `False`):
  128. Whether the model's input and output word embeddings should be tied.
  129. rope_theta (`float`, *optional*, defaults to 1000000.0):
  130. The base period of the RoPE embeddings.
  131. use_sliding_window (`bool`, *optional*, defaults to `False`):
  132. Whether to use sliding window attention.
  133. sliding_window (`int`, *optional*, defaults to 4096):
  134. Sliding window attention (SWA) window size. If not specified, will default to `4096`.
  135. max_window_layers (`int`, *optional*, defaults to 80):
  136. The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention.
  137. attention_dropout (`float`, *optional*, defaults to 0.0):
  138. The dropout ratio for the attention probabilities.
  139. vision_config (`Dict`, *optional*):
  140. The config for the visual encoder initialization.
  141. rope_scaling (`Dict`, *optional*):
  142. Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
  143. strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
  144. `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
  145. `max_position_embeddings` to the expected new maximum. See the following thread for more information on how
  146. these scaling strategies behave:
  147. https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
  148. experimental feature, subject to breaking API changes in future versions.
  149. """
  150. model_type = "qwen2_vl"
  151. keys_to_ignore_at_inference = ["past_key_values"]
  152. def __init__(
  153. self,
  154. vocab_size=152064,
  155. hidden_size=8192,
  156. intermediate_size=29568,
  157. num_hidden_layers=80,
  158. num_attention_heads=64,
  159. num_key_value_heads=8,
  160. hidden_act="silu",
  161. max_position_embeddings=32768,
  162. initializer_range=0.02,
  163. rms_norm_eps=1e-05,
  164. use_cache=True,
  165. tie_word_embeddings=False,
  166. rope_theta=1000000.0,
  167. use_sliding_window=False,
  168. sliding_window=4096,
  169. max_window_layers=80,
  170. attention_dropout=0.0,
  171. vision_config=None,
  172. rope_scaling=None,
  173. **kwargs,
  174. ):
  175. if isinstance(vision_config, dict):
  176. self.vision_config = Qwen2VLVisionConfig(**vision_config)
  177. elif vision_config is None:
  178. self.vision_config = Qwen2VLVisionConfig()
  179. self.vocab_size = vocab_size
  180. self.max_position_embeddings = max_position_embeddings
  181. self.hidden_size = hidden_size
  182. self.intermediate_size = intermediate_size
  183. self.num_hidden_layers = num_hidden_layers
  184. self.num_attention_heads = num_attention_heads
  185. self.use_sliding_window = use_sliding_window
  186. self.sliding_window = sliding_window
  187. self.max_window_layers = max_window_layers
  188. if num_key_value_heads is None:
  189. num_key_value_heads = num_attention_heads
  190. self.num_key_value_heads = num_key_value_heads
  191. self.hidden_act = hidden_act
  192. self.initializer_range = initializer_range
  193. self.rms_norm_eps = rms_norm_eps
  194. self.use_cache = use_cache
  195. self.rope_theta = rope_theta
  196. self.attention_dropout = attention_dropout
  197. self.rope_scaling = rope_scaling
  198. super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
  199. def get_triangle_upper_mask(x, mask=None):
  200. if mask is not None:
  201. return mask
  202. shape = x.shape
  203. shape[1] = 1
  204. mask = paddle.full(shape, paddle.finfo(x.dtype).min, dtype=x.dtype)
  205. mask = paddle.triu(mask, diagonal=1)
  206. mask.stop_gradient = True
  207. return mask
  208. def parallel_matmul(
  209. x: Tensor, y: Tensor, transpose_y=True, tensor_parallel_output=True
  210. ):
  211. is_fleet_init = True
  212. tensor_parallel_degree = 1
  213. try:
  214. hcg = fleet.get_hybrid_communicate_group()
  215. model_parallel_group = hcg.get_model_parallel_group()
  216. tensor_parallel_degree = hcg.get_model_parallel_world_size()
  217. except:
  218. is_fleet_init = False
  219. if paddle.in_dynamic_mode():
  220. y_is_distributed = y.is_distributed
  221. else:
  222. y_is_distributed = tensor_parallel_degree > 1
  223. if is_fleet_init and tensor_parallel_degree > 1 and y_is_distributed:
  224. input_parallel = paddle.distributed.collective._c_identity(
  225. x, group=model_parallel_group
  226. )
  227. logits = paddle.matmul(input_parallel, y, transpose_y=transpose_y)
  228. if tensor_parallel_output:
  229. return logits
  230. return paddle.distributed.collective._c_concat(
  231. logits, group=model_parallel_group
  232. )
  233. else:
  234. logits = paddle.matmul(x, y, transpose_y=transpose_y)
  235. return logits
  236. def _compute_default_rope_parameters(
  237. config: Optional[PretrainedConfig] = None,
  238. device: Optional["paddle.device"] = None,
  239. seq_len: Optional[int] = None,
  240. **rope_kwargs,
  241. ) -> Tuple["paddle.Tensor", float]:
  242. """
  243. Computes the inverse frequencies according to the original RoPE implementation
  244. Args:
  245. config ([`~transformers.PretrainedConfig`]):
  246. The model configuration.
  247. device (`paddle.device`):
  248. The device to use for initialization of the inverse frequencies.
  249. seq_len (`int`, *optional*):
  250. The current sequence length. Unused for this type of RoPE.
  251. rope_kwargs (`Dict`, *optional*):
  252. BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
  253. Returns:
  254. Tuple of (`paddle.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
  255. post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
  256. """
  257. if config is not None and len(rope_kwargs) > 0:
  258. raise ValueError(
  259. "Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in "
  260. f"`_compute_default_rope_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}"
  261. )
  262. if len(rope_kwargs) > 0:
  263. base = rope_kwargs["base"]
  264. dim = rope_kwargs["dim"]
  265. elif config is not None:
  266. base = config.rope_theta
  267. partial_rotary_factor = (
  268. config.partial_rotary_factor
  269. if hasattr(config, "partial_rotary_factor")
  270. else 1.0
  271. )
  272. head_dim = getattr(
  273. config, "head_dim", config.hidden_size // config.num_attention_heads
  274. )
  275. dim = int(head_dim * partial_rotary_factor)
  276. attention_factor = 1.0
  277. inv_freq = 1.0 / (
  278. base ** (paddle.arange(0, dim, 2, dtype="int64").astype("float32") / dim)
  279. )
  280. return inv_freq, attention_factor
  281. ROPE_INIT_FUNCTIONS = {
  282. "default": _compute_default_rope_parameters,
  283. }
  284. def _get_unpad_data(attention_mask):
  285. seqlens_in_batch = attention_mask.sum(axis=-1, dtype="int32")
  286. indices = paddle.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
  287. max_seqlen_in_batch = seqlens_in_batch.max().item() # [2, 1, 1323]
  288. cu_seqlens = F.pad(
  289. paddle.cumsum(seqlens_in_batch, axis=0), (1, 0), data_format="NCL"
  290. )
  291. return (
  292. indices,
  293. cu_seqlens,
  294. max_seqlen_in_batch,
  295. )
  296. def is_casual_mask(attention_mask):
  297. """
  298. Upper triangular of attention_mask equals to attention_mask is casual
  299. """
  300. return (paddle.triu(attention_mask) == attention_mask).all().item()
  301. def _make_causal_mask(input_ids_shape, past_key_values_length):
  302. """
  303. Make causal mask used for self-attention
  304. """
  305. batch_size, target_length = input_ids_shape
  306. mask = paddle.tril(paddle.ones((target_length, target_length), dtype="bool"))
  307. if past_key_values_length > 0:
  308. mask = paddle.concat(
  309. [paddle.ones([target_length, past_key_values_length], dtype="bool"), mask],
  310. axis=-1,
  311. )
  312. return mask[None, None, :, :].expand(
  313. [batch_size, 1, target_length, target_length + past_key_values_length]
  314. )
  315. def _expand_2d_mask(mask, dtype, tgt_length):
  316. """
  317. Expands attention_mask from `[batch_size, src_length]` to `[batch_size, 1, tgt_length, src_length]`.
  318. """
  319. batch_size, src_length = mask.shape[0], mask.shape[-1]
  320. tgt_length = tgt_length if tgt_length is not None else src_length
  321. mask = mask[:, None, None, :].astype("bool")
  322. mask.stop_gradient = True
  323. expanded_mask = mask.expand([batch_size, 1, tgt_length, src_length])
  324. return expanded_mask
  325. @dataclass
  326. class Qwen2VLCausalLMOutputWithPast(ModelOutput):
  327. """
  328. Base class for Qwen2VL causal language model (or autoregressive) outputs.
  329. Args:
  330. loss (`paddle.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
  331. Language modeling loss (for next-token prediction).
  332. logits (`paddle.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
  333. Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
  334. past_key_values (`tuple(tuple(paddle.Tensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
  335. Tuple of `tuple(paddle.Tensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
  336. `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
  337. Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
  338. `past_key_values` input) to speed up sequential decoding.
  339. hidden_states (`tuple(paddle.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
  340. Tuple of `paddle.Tensor` (one for the output of the embeddings, if the model has an embedding layer, +
  341. one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
  342. Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
  343. attentions (`tuple(paddle.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
  344. Tuple of `paddle.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
  345. sequence_length)`.
  346. Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
  347. heads.
  348. rope_deltas (`paddle.Tensor` of shape `(batch_size, )`, *optional*):
  349. The rope index difference between sequence length and multimodal rope.
  350. """
  351. loss: Optional[paddle.Tensor] = None
  352. logits: paddle.Tensor = None
  353. past_key_values: Optional[List[paddle.Tensor]] = None
  354. hidden_states: Optional[Tuple[paddle.Tensor]] = None
  355. attentions: Optional[Tuple[paddle.Tensor]] = None
  356. rope_deltas: Optional[paddle.Tensor] = None
  357. class Qwen2VLRotaryEmbedding(nn.Layer):
  358. def __init__(
  359. self,
  360. dim=None,
  361. max_position_embeddings=2048,
  362. base=10000,
  363. device=None,
  364. scaling_factor=1.0,
  365. rope_type="default",
  366. config: Optional[Qwen2VLConfig] = None,
  367. ):
  368. super().__init__()
  369. self.rope_kwargs = {}
  370. if config is None:
  371. self.rope_kwargs = {
  372. "rope_type": rope_type,
  373. "factor": scaling_factor,
  374. "dim": dim,
  375. "base": base,
  376. "max_position_embeddings": max_position_embeddings,
  377. }
  378. self.rope_type = rope_type
  379. self.max_seq_len_cached = max_position_embeddings
  380. self.original_max_seq_len = max_position_embeddings
  381. else:
  382. # BC: "rope_type" was originally "type"
  383. if config.rope_scaling is not None:
  384. self.rope_type = config.rope_scaling.get(
  385. "rope_type", config.rope_scaling.get("type")
  386. )
  387. else:
  388. self.rope_type = "default"
  389. self.max_seq_len_cached = config.max_position_embeddings
  390. self.original_max_seq_len = config.max_position_embeddings
  391. self.config = config
  392. self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
  393. self.inv_freq, self.attention_scaling = self.rope_init_fn(
  394. self.config, device, **self.rope_kwargs
  395. )
  396. self.original_inv_freq = self.inv_freq
  397. self._set_cos_sin_cache(seq_len=max_position_embeddings)
  398. def _set_cos_sin_cache(self, seq_len):
  399. self.max_seq_len_cached = seq_len
  400. t = paddle.arange(seq_len, dtype="float32")
  401. freqs = paddle.einsum("i,j->ij", t, self.inv_freq)
  402. emb = paddle.concat([freqs, freqs], axis=-1)
  403. self.cos_cached = emb.cos()
  404. self.sin_cached = emb.sin()
  405. def _dynamic_frequency_update(self, position_ids, device):
  406. """
  407. dynamic RoPE layers should recompute `inv_freq` in the following situations:
  408. 1 - growing beyond the cached sequence length (allow scaling)
  409. 2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
  410. """
  411. seq_len = paddle.max(position_ids) + 1
  412. if seq_len > self.max_seq_len_cached: # growth
  413. inv_freq, self.attention_scaling = self.rope_init_fn(
  414. self.config, device, seq_len=seq_len, **self.rope_kwargs
  415. )
  416. self.inv_freq = inv_freq
  417. self.max_seq_len_cached = seq_len
  418. if (
  419. seq_len < self.original_max_seq_len
  420. and self.max_seq_len_cached > self.original_max_seq_len
  421. ): # reset
  422. self.inv_freq = self.original_inv_freq
  423. self.max_seq_len_cached = self.original_max_seq_len
  424. @paddle.no_grad()
  425. def forward(self, x, position_ids):
  426. if "dynamic" in self.rope_type:
  427. self._dynamic_frequency_update(position_ids, device=x.device)
  428. inv_freq_expanded = (
  429. self.inv_freq[None, None, :, None]
  430. .astype("float32")
  431. .expand([3, position_ids.shape[1], -1, 1])
  432. )
  433. position_ids_expanded = position_ids[:, :, None, :].astype("float32")
  434. device_type = paddle.get_device()
  435. device_type = (
  436. device_type
  437. if isinstance(device_type, str) and device_type != "mps"
  438. else "cpu"
  439. )
  440. with paddle.amp.auto_cast():
  441. freqs = paddle.matmul(inv_freq_expanded, position_ids_expanded)
  442. freqs = freqs.transpose([0, 1, 3, 2])
  443. emb = paddle.concat((freqs, freqs), axis=-1)
  444. cos = emb.cos()
  445. sin = emb.sin()
  446. cos = cos * self.attention_scaling
  447. sin = sin * self.attention_scaling
  448. return cos.astype(x.dtype), sin.astype(x.dtype)
  449. # Copied from transformers.models.llama.modeling_llama.rotate_half
  450. def rotate_half(x):
  451. """Rotates half the hidden dims of the input."""
  452. x1 = x[..., : x.shape[-1] // 2]
  453. x2 = x[..., x.shape[-1] // 2 :]
  454. return paddle.concat([-x2, x1], axis=-1)
  455. def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
  456. """Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/).
  457. Explanation:
  458. Multimodal 3D rotary position embedding is an extension to 1D rotary position embedding. The input embedding
  459. sequence contains vision (images / videos) embedding and text embedding or just contains text embedding. For
  460. vision embedding part, we apply rotary position embedding on temporal, height and width dimension separately.
  461. Here we split the channel dimension to 3 chunks for the temporal, height and width rotary position embedding.
  462. For text embedding part, we just apply 1D rotary position embedding. The three rotary position index (temporal,
  463. height and width) of text embedding is always the same, so the text embedding rotary position embedding has no
  464. difference with modern LLMs.
  465. Args:
  466. q (`paddle.Tensor`): The query tensor.
  467. k (`paddle.Tensor`): The key tensor.
  468. cos (`paddle.Tensor`): The cosine part of the rotary embedding.
  469. sin (`paddle.Tensor`): The sine part of the rotary embedding.
  470. position_ids (`paddle.Tensor`):
  471. The position indices of the tokens corresponding to the query and key tensors. For example, this can be
  472. used to pass offsetted position ids when working with a KV-cache.
  473. mrope_section(`List(int)`):
  474. Multimodal rope section is for channel dimension of temporal, height and width in rope calculation.
  475. unsqueeze_dim (`int`, *optional*, defaults to 1):
  476. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  477. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  478. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  479. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  480. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  481. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  482. Returns:
  483. `tuple(paddle.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  484. """
  485. mrope_section = mrope_section * 2
  486. cos = paddle.concat(
  487. x=[m[i % 3] for i, m in enumerate(cos.split(mrope_section, axis=-1))], axis=-1
  488. ).unsqueeze(axis=unsqueeze_dim)
  489. sin = paddle.concat(
  490. x=[m[i % 3] for i, m in enumerate(sin.split(mrope_section, axis=-1))], axis=-1
  491. ).unsqueeze(axis=unsqueeze_dim)
  492. q_embed = (q * cos) + (rotate_half(q) * sin)
  493. k_embed = (k * cos) + (rotate_half(k) * sin)
  494. return q_embed, k_embed
  495. def apply_rotary_pos_emb_vision(
  496. tensor: paddle.Tensor, freqs: paddle.Tensor
  497. ) -> paddle.Tensor:
  498. orig_dtype = tensor.dtype
  499. with paddle.amp.auto_cast(False):
  500. tensor = tensor.astype(dtype="float32")
  501. cos = freqs.cos()
  502. sin = freqs.sin()
  503. cos = (
  504. cos.unsqueeze(1)
  505. .tile(repeat_times=[1, 1, 2])
  506. .unsqueeze(0)
  507. .astype(dtype="float32")
  508. )
  509. sin = (
  510. sin.unsqueeze(1)
  511. .tile(repeat_times=[1, 1, 2])
  512. .unsqueeze(0)
  513. .astype(dtype="float32")
  514. )
  515. output = tensor * cos + rotate_half(tensor) * sin
  516. output = paddle.cast(output, orig_dtype)
  517. return output
  518. class VisionRotaryEmbedding(nn.Layer):
  519. def __init__(self, dim: int, theta: float = 10000.0) -> None:
  520. super().__init__()
  521. self.inv_freq = 1.0 / theta ** (
  522. paddle.arange(start=0, end=dim, step=2, dtype="float32") / dim
  523. )
  524. def forward(self, seqlen: int) -> paddle.Tensor:
  525. seq = paddle.arange(seqlen).cast(self.inv_freq.dtype)
  526. freqs = paddle.outer(x=seq, y=self.inv_freq)
  527. return freqs
  528. class PatchEmbed(nn.Layer):
  529. def __init__(
  530. self,
  531. patch_size: int = 14,
  532. temporal_patch_size: int = 2,
  533. in_channels: int = 3,
  534. embed_dim: int = 1152,
  535. ) -> None:
  536. super().__init__()
  537. self.patch_size = patch_size
  538. self.temporal_patch_size = temporal_patch_size
  539. self.in_channels = in_channels
  540. self.embed_dim = embed_dim
  541. kernel_size = [temporal_patch_size, patch_size, patch_size]
  542. self.proj = nn.Conv3D(
  543. in_channels,
  544. embed_dim,
  545. kernel_size=kernel_size,
  546. stride=kernel_size,
  547. bias_attr=False,
  548. )
  549. def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor:
  550. target_dtype = self.proj.weight.dtype
  551. hidden_states = hidden_states.reshape(
  552. [
  553. -1,
  554. self.in_channels,
  555. self.temporal_patch_size,
  556. self.patch_size,
  557. self.patch_size,
  558. ]
  559. )
  560. # NOTE(changwenbin): AttributeError: 'Variable' object has no attribute 'to'.
  561. # hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).reshape([-1, self.embed_dim])
  562. # hidden_states = paddle.cast(hidden_states, dtype=target_dtype)
  563. hidden_states = self.proj(
  564. paddle.cast(hidden_states, dtype=target_dtype)
  565. ).reshape([-1, self.embed_dim])
  566. return hidden_states
  567. class PatchMerger(nn.Layer):
  568. def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2) -> None:
  569. super().__init__()
  570. self.hidden_size = context_dim * (spatial_merge_size**2)
  571. self.ln_q = nn.LayerNorm(context_dim, epsilon=1e-6)
  572. self.mlp = nn.Sequential(
  573. nn.Linear(self.hidden_size, self.hidden_size),
  574. nn.GELU(),
  575. nn.Linear(self.hidden_size, dim),
  576. )
  577. def forward(self, x: paddle.Tensor) -> paddle.Tensor:
  578. x = self.mlp(self.ln_q(x).reshape([-1, self.hidden_size]))
  579. return x
  580. class VisionMlp(nn.Layer):
  581. def __init__(self, dim: int, hidden_dim: int, hidden_act: str) -> None:
  582. super().__init__()
  583. self.fc1 = nn.Linear(dim, hidden_dim)
  584. self.act = ACT2FN[hidden_act]
  585. self.fc2 = nn.Linear(hidden_dim, dim)
  586. def forward(self, x) -> paddle.Tensor:
  587. return self.fc2(self.act(self.fc1(x)))
  588. class VisionAttention(nn.Layer):
  589. def __init__(self, dim: int, num_heads: int = 16) -> None:
  590. super().__init__()
  591. self.num_heads = num_heads
  592. self.qkv = nn.Linear(dim, dim * 3, bias_attr=True)
  593. self.proj = nn.Linear(dim, dim)
  594. self.head_dim = dim // num_heads # must added
  595. def forward(
  596. self,
  597. hidden_states: paddle.Tensor,
  598. cu_seqlens: paddle.Tensor,
  599. rotary_pos_emb: paddle.Tensor = None,
  600. ) -> paddle.Tensor:
  601. seq_length = hidden_states.shape[0]
  602. q, k, v = (
  603. self.qkv(hidden_states)
  604. .reshape([seq_length, 3, self.num_heads, -1])
  605. .transpose([1, 0, 2, 3])
  606. .unbind(0)
  607. )
  608. q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
  609. k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
  610. attention_mask = paddle.zeros([1, seq_length, seq_length], dtype="bool")
  611. for i in range(1, len(cu_seqlens)):
  612. attention_mask[
  613. ...,
  614. cu_seqlens[i - 1] : cu_seqlens[i],
  615. cu_seqlens[i - 1] : cu_seqlens[i],
  616. ] = True
  617. zero = paddle.zeros(attention_mask.shape, dtype=hidden_states.dtype)
  618. neg_inf = paddle.full_like(
  619. attention_mask,
  620. paddle.finfo(hidden_states.dtype).min,
  621. dtype=hidden_states.dtype,
  622. )
  623. attention_mask = paddle.where(attention_mask, zero, neg_inf)
  624. q = q.transpose([1, 0, 2])
  625. k = k.transpose([1, 0, 2])
  626. v = v.transpose([1, 0, 2])
  627. attn_weights = paddle.matmul(q, k.transpose([0, 2, 1])) / math.sqrt(
  628. self.head_dim
  629. )
  630. attn_weights = attn_weights + attention_mask
  631. attn_weights = nn.functional.softmax(attn_weights, axis=-1, dtype="float32")
  632. attn_output = paddle.matmul(attn_weights, v)
  633. attn_output = attn_output.transpose([1, 0, 2])
  634. attn_output = attn_output.reshape([seq_length, -1])
  635. attn_output = self.proj(attn_output)
  636. return attn_output
  637. class VisionFlashAttention2(nn.Layer):
  638. def __init__(self, dim: int, num_heads: int = 16) -> None:
  639. super().__init__()
  640. self.num_heads = num_heads
  641. self.qkv = nn.Linear(dim, dim * 3, bias_attr=True)
  642. self.proj = nn.Linear(dim, dim)
  643. self.head_dim = dim // num_heads # must added
  644. def forward(
  645. self,
  646. hidden_states: paddle.Tensor,
  647. cu_seqlens: paddle.Tensor,
  648. rotary_pos_emb: paddle.Tensor = None,
  649. ) -> paddle.Tensor:
  650. seq_length = tuple(hidden_states.shape)[0]
  651. qkv = (
  652. self.qkv(hidden_states)
  653. .reshape([seq_length, 3, self.num_heads, -1])
  654. .transpose(perm=[1, 0, 2, 3])
  655. )
  656. q, k, v = qkv.unbind(axis=0)
  657. q = apply_rotary_pos_emb_vision(q.unsqueeze(axis=0), rotary_pos_emb).squeeze(
  658. axis=0
  659. )
  660. k = apply_rotary_pos_emb_vision(k.unsqueeze(axis=0), rotary_pos_emb).squeeze(
  661. axis=0
  662. )
  663. if _IS_NPU:
  664. attn_output = paddle.nn.functional.flash_attention_npu(
  665. q.astype("bfloat16"),
  666. k.astype("bfloat16"),
  667. v.astype("bfloat16"),
  668. is_varlen=True,
  669. batch_size=1,
  670. seq_length=seq_length,
  671. ).reshape([seq_length, -1])
  672. else:
  673. max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
  674. softmax_scale = self.head_dim**-0.5
  675. attn_output = (
  676. flash_attn_varlen_func(
  677. q.astype("bfloat16"),
  678. k.astype("bfloat16"),
  679. v.astype("bfloat16"),
  680. cu_seqlens,
  681. cu_seqlens,
  682. max_seqlen,
  683. max_seqlen,
  684. scale=softmax_scale,
  685. )[0]
  686. .squeeze(0)
  687. .reshape([seq_length, -1])
  688. )
  689. if self.proj.weight.dtype == paddle.bfloat16:
  690. attn_output = attn_output.astype(paddle.bfloat16)
  691. elif self.proj.weight.dtype == paddle.float16:
  692. attn_output = attn_output.astype(paddle.float16)
  693. elif self.proj.weight.dtype == paddle.float32:
  694. attn_output = attn_output.astype(paddle.float32)
  695. attn_output = self.proj(attn_output)
  696. return attn_output
  697. def create_attention_module(config, module_type, layer_idx=None):
  698. if flash_attn_func is not None:
  699. if module_type == "qwen2vl":
  700. return Qwen2VLFlashAttention2(config, layer_idx)
  701. elif module_type == "vision":
  702. return VisionFlashAttention2(config.embed_dim, num_heads=config.num_heads)
  703. else:
  704. logging.warning_once(
  705. f"Warning: Flash Attention2 is not available for {module_type}, fallback to normal attention."
  706. )
  707. if module_type == "qwen2vl":
  708. return Qwen2VLAttention(config, layer_idx)
  709. elif module_type == "vision":
  710. return VisionAttention(config.embed_dim, num_heads=config.num_heads)
  711. class Qwen2VLVisionBlock(nn.Layer):
  712. def __init__(self, config, attn_implementation: str = "flash_attention_2") -> None:
  713. super().__init__()
  714. self.norm1 = nn.LayerNorm(config.embed_dim, epsilon=1e-6)
  715. self.norm2 = nn.LayerNorm(config.embed_dim, epsilon=1e-6)
  716. mlp_hidden_dim = int(config.embed_dim * config.mlp_ratio)
  717. self.attn = create_attention_module(config, "vision")
  718. self.mlp = VisionMlp(
  719. dim=config.embed_dim,
  720. hidden_dim=mlp_hidden_dim,
  721. hidden_act=config.hidden_act,
  722. )
  723. def forward(self, hidden_states, cu_seqlens, rotary_pos_emb) -> paddle.Tensor:
  724. hidden_states = hidden_states + self.attn(
  725. self.norm1(hidden_states),
  726. cu_seqlens=cu_seqlens,
  727. rotary_pos_emb=rotary_pos_emb,
  728. )
  729. hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
  730. return hidden_states
  731. def _prepare_4d_causal_attention_mask_with_cache_position(
  732. attention_mask: paddle.Tensor,
  733. sequence_length: int,
  734. target_length: int,
  735. dtype: paddle.dtype,
  736. min_dtype: float,
  737. cache_position: paddle.Tensor,
  738. batch_size: int,
  739. ):
  740. """
  741. Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
  742. `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
  743. Args:
  744. attention_mask (`paddle.Tensor`):
  745. A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
  746. sequence_length (`int`):
  747. The sequence length being processed.
  748. target_length (`int`):
  749. The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
  750. dtype (`paddle.dtype`):
  751. The dtype to use for the 4D attention mask.
  752. min_dtype (`float`):
  753. The minimum value representable with the dtype `dtype`.
  754. cache_position (`paddle.Tensor`):
  755. Indices depicting the position of the input sequence tokens in the sequence.
  756. batch_size (`paddle.Tensor`):
  757. Batch size.
  758. """
  759. if attention_mask is not None and attention_mask.dim() == 4:
  760. causal_mask = attention_mask
  761. else:
  762. causal_mask = paddle.full(
  763. [sequence_length, target_length], fill_value=min_dtype, dtype=dtype
  764. )
  765. if sequence_length != 1:
  766. causal_mask = paddle.triu(x=causal_mask, diagonal=1)
  767. causal_mask *= paddle.arange(target_length) > cache_position.reshape([-1, 1])
  768. causal_mask = causal_mask[None, None, :, :].expand(
  769. shape=[batch_size, 1, -1, -1]
  770. )
  771. if attention_mask is not None:
  772. causal_mask = causal_mask.clone()
  773. mask_length = tuple(attention_mask.shape)[-1]
  774. padding_mask = (
  775. causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
  776. )
  777. padding_mask = padding_mask == 0
  778. causal_mask[:, :, :, :mask_length] = causal_mask[
  779. :, :, :, :mask_length
  780. ].masked_fill(mask=padding_mask, value=min_dtype)
  781. return causal_mask
  782. class Qwen2RMSNorm(nn.Layer):
  783. def __init__(self, config: Qwen2VLConfig, hidden_size, eps=1e-6):
  784. """
  785. Qwen2RMSNorm is equivalent to T5LayerNorm
  786. """
  787. super().__init__()
  788. self.weight = paddle.create_parameter(
  789. shape=[hidden_size],
  790. dtype=paddle.get_default_dtype(),
  791. default_initializer=nn.initializer.Constant(1.0),
  792. )
  793. self.variance_epsilon = eps
  794. def forward(self, hidden_states):
  795. if paddle.in_dynamic_mode():
  796. with paddle.amp.auto_cast(False):
  797. variance = hidden_states.astype("float32").pow(2).mean(-1, keepdim=True)
  798. hidden_states = (
  799. paddle.rsqrt(variance + self.variance_epsilon) * hidden_states
  800. )
  801. else:
  802. variance = hidden_states.astype("float32").pow(2).mean(-1, keepdim=True)
  803. hidden_states = (
  804. paddle.rsqrt(variance + self.variance_epsilon) * hidden_states
  805. )
  806. if self.weight.dtype in [paddle.float16, paddle.bfloat16]:
  807. hidden_states = paddle.cast(hidden_states, self.weight.dtype)
  808. return hidden_states * self.weight
  809. class Qwen2MLP(nn.Layer):
  810. def __init__(self, config):
  811. super().__init__()
  812. self.hidden_size = config.hidden_size
  813. self.intermediate_size = config.intermediate_size
  814. self.fuse_attention_ffn = config.fuse_attention_ffn
  815. self.tensor_parallel_degree = config.tensor_parallel_degree
  816. if config.tensor_parallel_degree > 1:
  817. self.gate_proj = ColumnParallelLinear(
  818. self.hidden_size,
  819. self.intermediate_size,
  820. gather_output=False,
  821. has_bias=False,
  822. )
  823. self.up_proj = ColumnParallelLinear(
  824. self.hidden_size,
  825. self.intermediate_size,
  826. gather_output=False,
  827. has_bias=False,
  828. )
  829. self.down_proj = RowParallelLinear(
  830. self.intermediate_size,
  831. self.hidden_size,
  832. input_is_parallel=True,
  833. has_bias=False,
  834. )
  835. else:
  836. self.gate_proj = Linear(
  837. self.hidden_size, self.intermediate_size, bias_attr=False
  838. ) # w1
  839. self.up_proj = Linear(
  840. self.hidden_size, self.intermediate_size, bias_attr=False
  841. ) # w3
  842. self.down_proj = Linear(
  843. self.intermediate_size, self.hidden_size, bias_attr=False
  844. ) # w2
  845. self.act_fn = ACT2FN[config.hidden_act]
  846. self.fuse_swiglu = False
  847. def forward(self, x):
  848. x, y = self.gate_proj(x), self.up_proj(x)
  849. if self.fuse_swiglu:
  850. x = self.act_fn(x, y)
  851. else:
  852. x = self.act_fn(x) * y
  853. return self.down_proj(x)
  854. # Copied from transformers.models.llama.modeling_llama.repeat_kv
  855. def repeat_kv(hidden_states: paddle.Tensor, n_rep: int) -> paddle.Tensor:
  856. """
  857. This is the equivalent of paddle.repeat_interleave(x, axis=1, repeats=n_rep). The hidden states go from (batch,
  858. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  859. """
  860. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  861. if n_rep == 1:
  862. return hidden_states
  863. hidden_states = hidden_states[:, :, None, :, :].expand(
  864. [batch, num_key_value_heads, n_rep, slen, head_dim]
  865. )
  866. return hidden_states.reshape([batch, num_key_value_heads * n_rep, slen, head_dim])
  867. class Qwen2VLAttention(nn.Layer):
  868. """
  869. Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
  870. and "Generating Long Sequences with Sparse Transformers".
  871. """
  872. def __init__(self, config: Qwen2VLConfig, layer_idx: Optional[int] = None):
  873. super().__init__()
  874. self.config = config
  875. self.layer_idx = layer_idx
  876. if layer_idx is None:
  877. logging.warning_once(
  878. f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
  879. "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
  880. "when creating this class."
  881. )
  882. self.hidden_size = config.hidden_size
  883. self.num_heads = config.num_attention_heads
  884. self.head_dim = self.hidden_size // self.num_heads
  885. self.num_key_value_heads = config.num_key_value_heads
  886. self.num_key_value_groups = self.num_heads // self.num_key_value_heads
  887. self.max_position_embeddings = config.max_position_embeddings
  888. self.rope_theta = config.rope_theta
  889. self.is_causal = True
  890. self.attention_dropout = config.attention_dropout
  891. self.rope_scaling = config.rope_scaling
  892. # self.sequence_parallel = config.sequence_parallel
  893. if config.tensor_parallel_degree > 1:
  894. assert (
  895. self.num_heads % config.tensor_parallel_degree == 0
  896. ), f"num_heads: {self.num_heads}, tensor_parallel_degree: {config.tensor_parallel_degree}"
  897. self.num_heads = self.num_heads // config.tensor_parallel_degree
  898. assert (
  899. self.num_key_value_heads % config.tensor_parallel_degree == 0
  900. ), f"num_key_value_heads: {self.num_key_value_heads}, tensor_parallel_degree: {config.tensor_parallel_degree}"
  901. self.num_key_value_heads = (
  902. self.num_key_value_heads // config.tensor_parallel_degree
  903. )
  904. if config.tensor_parallel_degree > 1:
  905. self.q_proj = ColumnParallelLinear(
  906. self.hidden_size, self.hidden_size, has_bias=True, gather_output=False
  907. )
  908. self.k_proj = ColumnParallelLinear(self.hidden_size, self.config.num_key_value_heads * self.head_dim, has_bias=True, gather_output=False) # fmt:skip
  909. self.v_proj = ColumnParallelLinear(self.hidden_size, self.config.num_key_value_heads * self.head_dim, has_bias=True, gather_output=False) # fmt:skip
  910. self.o_proj = RowParallelLinear(
  911. self.hidden_size,
  912. self.hidden_size,
  913. has_bias=False,
  914. input_is_parallel=True,
  915. )
  916. else:
  917. self.q_proj = Linear(self.hidden_size, self.hidden_size, bias_attr=True)
  918. self.k_proj = Linear(
  919. self.hidden_size,
  920. self.config.num_key_value_heads * self.head_dim,
  921. bias_attr=True,
  922. )
  923. self.v_proj = Linear(
  924. self.hidden_size,
  925. self.config.num_key_value_heads * self.head_dim,
  926. bias_attr=True,
  927. )
  928. self.o_proj = Linear(self.hidden_size, self.hidden_size, bias_attr=False)
  929. self.rotary_emb = Qwen2VLRotaryEmbedding(
  930. self.head_dim,
  931. max_position_embeddings=self.max_position_embeddings,
  932. base=self.rope_theta,
  933. )
  934. def forward(
  935. self,
  936. hidden_states: paddle.Tensor,
  937. attention_mask: Optional[paddle.Tensor] = None,
  938. position_ids: Optional[paddle.Tensor] = None,
  939. past_key_value: Optional[Tuple[paddle.Tensor]] = None,
  940. output_attentions: bool = False,
  941. use_cache: bool = False, # default true
  942. cache_position: Optional[paddle.Tensor] = None,
  943. ) -> Tuple[paddle.Tensor, Optional[paddle.Tensor], Optional[Tuple[paddle.Tensor]]]:
  944. bsz, q_len, _ = hidden_states.shape
  945. try:
  946. query_states = self.q_proj(hidden_states)
  947. key_states = self.k_proj(hidden_states)
  948. value_states = self.v_proj(hidden_states)
  949. except:
  950. hidden_states = hidden_states.astype(self.config.dtype)
  951. query_states = self.q_proj(hidden_states)
  952. key_states = self.k_proj(hidden_states)
  953. value_states = self.v_proj(hidden_states)
  954. target_query_shape = [0, 0, self.num_heads, self.head_dim]
  955. target_key_value_shape = [0, 0, self.num_key_value_heads, self.head_dim]
  956. query_states = query_states.reshape(shape=target_query_shape)
  957. key_states = key_states.reshape(shape=target_key_value_shape)
  958. value_states = value_states.reshape(shape=target_key_value_shape)
  959. new_perm = [0, 2, 1, 3]
  960. query_states = query_states.transpose(new_perm)
  961. key_states = key_states.transpose(new_perm)
  962. value_states = value_states.transpose(new_perm)
  963. kv_seq_len = key_states.shape[-2]
  964. if past_key_value is not None:
  965. kv_seq_len += cache_position[0] + 1
  966. cos, sin = self.rotary_emb(value_states, position_ids)
  967. query_states, key_states = apply_multimodal_rotary_pos_emb(
  968. query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
  969. )
  970. if past_key_value is not None:
  971. key_states = paddle.concat([past_key_value[0], key_states], axis=2)
  972. value_states = paddle.concat([past_key_value[1], value_states], axis=2)
  973. past_key_value = (key_states, value_states) if use_cache else None
  974. # repeat k/v heads if n_kv_heads < n_heads
  975. key_states = repeat_kv(key_states, self.num_key_value_groups)
  976. value_states = repeat_kv(value_states, self.num_key_value_groups)
  977. query_states = query_states.astype("float32")
  978. key_states = key_states.astype("float32")
  979. value_states = value_states.astype("float32")
  980. attn_weights = paddle.matmul(
  981. query_states, key_states.transpose([0, 1, 3, 2])
  982. ) / math.sqrt(self.head_dim)
  983. if attention_mask is not None:
  984. attn_weights = attn_weights + attention_mask
  985. attn_weights = nn.functional.softmax(attn_weights, axis=-1, dtype="float32")
  986. attn_weights = nn.functional.dropout(
  987. x=attn_weights, p=self.attention_dropout, training=self.training
  988. )
  989. attn_output = paddle.matmul(
  990. attn_weights.cast(self.config.dtype), value_states.cast(self.config.dtype)
  991. )
  992. if attn_output.shape != [bsz, self.num_heads, q_len, self.head_dim]:
  993. raise ValueError(
  994. f"`attn_output` should be of size {(bsz, q_len, self.num_heads, self.head_dim)}, but is"
  995. f" {attn_output.shape}"
  996. )
  997. attn_output = attn_output.transpose([0, 2, 1, 3])
  998. attn_output = attn_output.reshape([bsz, q_len, -1])
  999. if self.o_proj.weight.dtype == paddle.bfloat16:
  1000. attn_output = attn_output.astype(paddle.bfloat16)
  1001. elif self.o_proj.weight.dtype == paddle.float16:
  1002. attn_output = attn_output.astype(paddle.float16)
  1003. elif self.o_proj.weight.dtype == paddle.float32:
  1004. attn_output = attn_output.astype(paddle.float32)
  1005. attn_output = self.o_proj(attn_output)
  1006. if not output_attentions:
  1007. attn_weights = None
  1008. return attn_output, attn_weights, past_key_value
  1009. class Qwen2VLFlashAttention2(Qwen2VLAttention):
  1010. """
  1011. Qwen2VL flash attention module, following Qwen2VL attention module. This module inherits from `Qwen2VLAttention`
  1012. as the weights of the module stays untouched. The only required change would be on the forward pass
  1013. where it needs to correctly call the public API of flash attention and deal with padding tokens
  1014. in case the input contains any of them. Additionally, for sliding window attention, we apply SWA only to the bottom
  1015. config.max_window_layers layers.
  1016. """
  1017. def __init__(self, *args, **kwargs):
  1018. super().__init__(*args, **kwargs)
  1019. def forward(
  1020. self,
  1021. hidden_states: paddle.Tensor,
  1022. attention_mask: Optional[paddle.Tensor] = None,
  1023. position_ids: Optional[paddle.Tensor] = None,
  1024. past_key_value: Optional[Tuple[paddle.Tensor]] = None,
  1025. output_attentions: bool = False,
  1026. use_cache: bool = False, # default true
  1027. cache_position: Optional[paddle.Tensor] = None,
  1028. ) -> Tuple[paddle.Tensor, Optional[paddle.Tensor], Optional[Tuple[paddle.Tensor]]]:
  1029. bsz, q_len, _ = tuple(hidden_states.shape)
  1030. try:
  1031. query_states = self.q_proj(hidden_states)
  1032. key_states = self.k_proj(hidden_states)
  1033. value_states = self.v_proj(hidden_states)
  1034. except:
  1035. hidden_states = hidden_states.astype("bfloat16")
  1036. query_states = self.q_proj(hidden_states)
  1037. key_states = self.k_proj(hidden_states)
  1038. value_states = self.v_proj(hidden_states)
  1039. target_query_shape = [0, 0, self.num_heads, self.head_dim]
  1040. target_key_value_shape = [0, 0, self.num_key_value_heads, self.head_dim]
  1041. query_states = query_states.reshape(shape=target_query_shape)
  1042. key_states = key_states.reshape(shape=target_key_value_shape)
  1043. value_states = value_states.reshape(shape=target_key_value_shape)
  1044. new_perm = [0, 2, 1, 3]
  1045. query_states = query_states.transpose(new_perm)
  1046. key_states = key_states.transpose(new_perm)
  1047. value_states = value_states.transpose(new_perm)
  1048. kv_seq_len = key_states.shape[-2]
  1049. if past_key_value is not None:
  1050. kv_seq_len += cache_position[0] + 1
  1051. # Because the input can be padded, the absolute sequence length depends on the max position id.
  1052. cos, sin = self.rotary_emb(value_states, position_ids)
  1053. query_states, key_states = apply_multimodal_rotary_pos_emb(
  1054. query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
  1055. )
  1056. if past_key_value is not None:
  1057. key_states = paddle.concat([past_key_value[0], key_states], axis=2)
  1058. value_states = paddle.concat([past_key_value[1], value_states], axis=2)
  1059. past_key_value = (key_states, value_states) if use_cache else None
  1060. # repeat k/v heads if n_kv_heads < n_heads
  1061. key_states = repeat_kv(key_states, self.num_key_value_groups)
  1062. value_states = repeat_kv(value_states, self.num_key_value_groups)
  1063. # Reashape to the expected shape for Flash Attention
  1064. # [1, 3599, 12, 128]
  1065. query_states = query_states.transpose(perm=[0, 2, 1, 3])
  1066. key_states = key_states.transpose(perm=[0, 2, 1, 3])
  1067. value_states = value_states.transpose(perm=[0, 2, 1, 3])
  1068. attn_output = self._flash_attention_forward(
  1069. query_states, key_states, value_states, attention_mask, q_len
  1070. )
  1071. attn_output = attn_output.reshape([bsz, q_len, -1])
  1072. attn_output = self.o_proj(attn_output)
  1073. if not output_attentions:
  1074. attn_weights = None
  1075. return attn_output, attn_weights, past_key_value
  1076. def _flash_attention_forward(
  1077. self,
  1078. query_states,
  1079. key_states,
  1080. value_states,
  1081. attention_mask,
  1082. query_length,
  1083. dropout=0.0,
  1084. softmax_scale=None,
  1085. ):
  1086. """
  1087. Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
  1088. first unpad the input, then computes the attention scores and pad the final attention scores.
  1089. Args:
  1090. query_states (`paddle.Tensor`):
  1091. Input query states to be passed to Flash Attention API
  1092. key_states (`paddle.Tensor`):
  1093. Input key states to be passed to Flash Attention API
  1094. value_states (`paddle.Tensor`):
  1095. Input value states to be passed to Flash Attention API
  1096. attention_mask (`paddle.Tensor`):
  1097. The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
  1098. position of padding tokens and 1 for the position of non-padding tokens.
  1099. dropout (`int`, *optional*):
  1100. Attention dropout
  1101. softmax_scale (`float`, *optional*):
  1102. The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
  1103. """
  1104. # Contains at least one padding token in the sequence
  1105. causal = self.is_causal and query_length != 1
  1106. if _IS_NPU:
  1107. if attention_mask is not None:
  1108. attn_output = paddle.nn.functional.flash_attention_npu( # TODO: flash_attn_unpadded
  1109. query_states,
  1110. key_states,
  1111. value_states,
  1112. attn_mask=attention_mask,
  1113. dropout=dropout,
  1114. causal=causal,
  1115. is_varlen=True,
  1116. )
  1117. else:
  1118. dtype = query_states.dtype
  1119. attn_output = paddle.nn.functional.flash_attention_npu( # TODO: flash_attn_unpadded
  1120. query_states.astype("bfloat16"),
  1121. key_states.astype("bfloat16"),
  1122. value_states.astype("bfloat16"),
  1123. attn_mask=attention_mask,
  1124. dropout=dropout,
  1125. causal=causal,
  1126. )
  1127. attn_output = attn_output.astype(dtype)
  1128. else:
  1129. head_dim = query_states.shape[-1]
  1130. softmax_scale = head_dim**-0.5 # TODO: 需要手动加上
  1131. if attention_mask is not None:
  1132. batch_size = query_states.shape[0]
  1133. (
  1134. query_states,
  1135. key_states,
  1136. value_states,
  1137. indices_q,
  1138. cu_seq_lens,
  1139. max_seq_lens,
  1140. ) = self._unpad_input(
  1141. query_states, key_states, value_states, attention_mask, query_length
  1142. )
  1143. cu_seqlens_q, cu_seqlens_k = cu_seq_lens
  1144. max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
  1145. attn_output_unpad = flash_attn_varlen_func(
  1146. query_states,
  1147. key_states,
  1148. value_states,
  1149. cu_seqlens_q=cu_seqlens_q,
  1150. cu_seqlens_k=cu_seqlens_k,
  1151. max_seqlen_q=max_seqlen_in_batch_q,
  1152. max_seqlen_k=max_seqlen_in_batch_k,
  1153. scale=softmax_scale, # not softmax_scale=
  1154. dropout=dropout,
  1155. causal=causal,
  1156. )[0]
  1157. attn_output = pad_input(
  1158. attn_output_unpad, indices_q, batch_size, query_length
  1159. )
  1160. else:
  1161. attn_output = flash_attn_func(
  1162. query_states,
  1163. key_states,
  1164. value_states,
  1165. dropout,
  1166. causal=causal,
  1167. )[0]
  1168. return attn_output
  1169. def _unpad_input(
  1170. self, query_layer, key_layer, value_layer, attention_mask, query_length
  1171. ):
  1172. indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
  1173. batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
  1174. # TODO:cuda error
  1175. key_layer = index_first_axis(
  1176. key_layer.reshape([batch_size * kv_seq_len, num_key_value_heads, head_dim]),
  1177. indices_k,
  1178. )
  1179. value_layer = index_first_axis(
  1180. value_layer.reshape(
  1181. [batch_size * kv_seq_len, num_key_value_heads, head_dim]
  1182. ),
  1183. indices_k,
  1184. )
  1185. if query_length == kv_seq_len:
  1186. query_layer = index_first_axis(
  1187. query_layer.reshape(
  1188. [batch_size * kv_seq_len, self.num_heads, head_dim]
  1189. ),
  1190. indices_k,
  1191. )
  1192. cu_seqlens_q = cu_seqlens_k
  1193. max_seqlen_in_batch_q = max_seqlen_in_batch_k
  1194. indices_q = indices_k
  1195. elif query_length == 1:
  1196. max_seqlen_in_batch_q = 1
  1197. cu_seqlens_q = paddle.arange(
  1198. batch_size + 1, dtype=paddle.int32
  1199. ) # There is a memcpy here, that is very bad.
  1200. indices_q = cu_seqlens_q[:-1]
  1201. query_layer = query_layer.squeeze(1)
  1202. else:
  1203. # The -q_len: slice assumes left padding.
  1204. attention_mask = attention_mask[:, -query_length:]
  1205. query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(
  1206. query_layer, attention_mask
  1207. )
  1208. return (
  1209. query_layer,
  1210. key_layer,
  1211. value_layer,
  1212. indices_q.to(paddle.int64),
  1213. (cu_seqlens_q, cu_seqlens_k),
  1214. (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
  1215. )
  1216. class Qwen2VLDecoderLayer(nn.Layer):
  1217. def __init__(self, config: Qwen2VLConfig, layer_idx: int):
  1218. super().__init__()
  1219. self.hidden_size = config.hidden_size
  1220. # use_sliding_window false
  1221. if (
  1222. config.use_sliding_window
  1223. and config.attn_implementation != "flash_attention_2"
  1224. ):
  1225. logging.warning_once(
  1226. f"Sliding Window Attention is enabled but not implemented for `{config.attn_implementation}`; "
  1227. "unexpected results may be encountered."
  1228. )
  1229. self.self_attn = create_attention_module(config, "qwen2vl", layer_idx=layer_idx)
  1230. # self.self_attn = Qwen2VLAttention(config, layer_idx)
  1231. self.mlp = Qwen2MLP(config)
  1232. self.input_layernorm = Qwen2RMSNorm(
  1233. config, config.hidden_size, eps=config.rms_norm_eps
  1234. )
  1235. self.post_attention_layernorm = Qwen2RMSNorm(
  1236. config, config.hidden_size, eps=config.rms_norm_eps
  1237. )
  1238. def forward(
  1239. self,
  1240. hidden_states: paddle.Tensor,
  1241. attention_mask: Optional[paddle.Tensor] = None,
  1242. position_ids: Optional[paddle.Tensor] = None,
  1243. past_key_value: Optional[Tuple[paddle.Tensor]] = None,
  1244. output_attentions: Optional[bool] = False,
  1245. use_cache: Optional[bool] = False,
  1246. cache_position: Optional[paddle.Tensor] = None,
  1247. **kwargs,
  1248. ):
  1249. """
  1250. Args:
  1251. hidden_states (`paddle.Tensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
  1252. attention_mask (`paddle.Tensor`, *optional*): attention mask of size
  1253. `(batch, sequence_length)` where padding elements are indicated by 0.
  1254. output_attentions (`bool`, *optional*):
  1255. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  1256. returned tensors for more detail.
  1257. use_cache (`bool`, *optional*):
  1258. If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
  1259. (see `past_key_values`).
  1260. past_key_value (`Tuple(paddle.Tensor)`, *optional*): cached past key and value projection states
  1261. cache_position (`paddle.Tensor` of shape `(sequence_length)`, *optional*):
  1262. Indices depicting the position of the input sequence tokens in the sequence.
  1263. kwargs (`dict`, *optional*):
  1264. Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
  1265. into the model
  1266. """
  1267. residual = hidden_states
  1268. hidden_states = self.input_layernorm(hidden_states)
  1269. # Self Attention
  1270. hidden_states, self_attn_weights, present_key_value = self.self_attn(
  1271. hidden_states=hidden_states,
  1272. attention_mask=attention_mask,
  1273. position_ids=position_ids,
  1274. past_key_value=past_key_value,
  1275. output_attentions=output_attentions,
  1276. use_cache=use_cache,
  1277. cache_position=cache_position,
  1278. )
  1279. hidden_states = residual + hidden_states
  1280. # Fully Connected
  1281. residual = hidden_states
  1282. hidden_states = self.post_attention_layernorm(hidden_states)
  1283. hidden_states = self.mlp(hidden_states)
  1284. hidden_states = residual + hidden_states
  1285. outputs = (hidden_states,)
  1286. if output_attentions:
  1287. outputs += (self_attn_weights,)
  1288. if use_cache:
  1289. outputs += (present_key_value,)
  1290. return outputs
  1291. class Qwen2VLPreTrainedModel(PretrainedModel):
  1292. config_class = Qwen2VLConfig
  1293. base_model_prefix = "model"
  1294. _no_split_modules = ["Qwen2VLDecoderLayer", "Qwen2VLVisionBlock"]
  1295. _skip_keys_device_placement = "past_key_values"
  1296. def _init_weights(self, layer):
  1297. std = 0.2
  1298. if isinstance(layer, (nn.Linear, nn.Conv3D)):
  1299. nn.initializer.Normal(mean=0.0, std=std)(layer.weight)
  1300. if layer.bias is not None:
  1301. nn.initializer.Constant(0.0)(layer.bias)
  1302. elif isinstance(layer, nn.Embedding):
  1303. nn.initializer.Normal(mean=0.0, std=std)(layer.weight)
  1304. if layer._padding_idx is not None:
  1305. with paddle.no_grad():
  1306. layer.weight[layer._padding_idx] = 0.0
  1307. class Qwen2VisionTransformerPretrainedModel(Qwen2VLPreTrainedModel):
  1308. config_class = Qwen2VLVisionConfig
  1309. _no_split_modules = ["Qwen2VLVisionBlock"]
  1310. def __init__(self, config) -> None:
  1311. super().__init__(config)
  1312. self.spatial_merge_size = config.spatial_merge_size
  1313. self.patch_embed = PatchEmbed(
  1314. patch_size=config.patch_size,
  1315. temporal_patch_size=config.temporal_patch_size,
  1316. in_channels=config.in_channels,
  1317. embed_dim=config.embed_dim,
  1318. )
  1319. head_dim = config.embed_dim // config.num_heads
  1320. self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)
  1321. self.blocks = nn.LayerList(
  1322. [Qwen2VLVisionBlock(config) for _ in range(config.depth)]
  1323. )
  1324. self.merger = PatchMerger(dim=config.hidden_size, context_dim=config.embed_dim)
  1325. self.enable_recompute = False
  1326. def get_dtype(self) -> paddle.dtype:
  1327. return self.blocks[0].mlp.fc2.weight.dtype
  1328. def rot_pos_emb(self, grid_thw):
  1329. pos_ids = []
  1330. for t, h, w in grid_thw:
  1331. hpos_ids = paddle.arange(h).unsqueeze(1).expand([-1, w])
  1332. hpos_ids = hpos_ids.reshape(
  1333. [
  1334. h // self.spatial_merge_size,
  1335. self.spatial_merge_size,
  1336. w // self.spatial_merge_size,
  1337. self.spatial_merge_size,
  1338. ]
  1339. )
  1340. hpos_ids = hpos_ids.transpose(perm=[0, 2, 1, 3])
  1341. hpos_ids = hpos_ids.flatten()
  1342. wpos_ids = paddle.arange(w).unsqueeze(0).expand([h, -1])
  1343. wpos_ids = wpos_ids.reshape(
  1344. [
  1345. h // self.spatial_merge_size,
  1346. self.spatial_merge_size,
  1347. w // self.spatial_merge_size,
  1348. self.spatial_merge_size,
  1349. ]
  1350. )
  1351. wpos_ids = wpos_ids.transpose([0, 2, 1, 3])
  1352. wpos_ids = wpos_ids.flatten()
  1353. pos_ids.append(
  1354. paddle.stack(x=[hpos_ids, wpos_ids], axis=-1).tile(repeat_times=[t, 1])
  1355. )
  1356. pos_ids = paddle.concat(x=pos_ids, axis=0)
  1357. max_grid_size = grid_thw[:, 1:].max()
  1358. rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
  1359. rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(start_axis=1)
  1360. return rotary_pos_emb
  1361. @paddle.jit.not_to_static
  1362. def recompute_training_full(
  1363. self,
  1364. layer_module: nn.Layer,
  1365. hidden_states: paddle.Tensor,
  1366. cu_seqlens_now: paddle.Tensor,
  1367. rotary_pos_emb: paddle.Tensor,
  1368. ):
  1369. def create_custom_forward(module):
  1370. def custom_forward(*inputs):
  1371. return module(*inputs)
  1372. return custom_forward
  1373. hidden_states = recompute(
  1374. create_custom_forward(layer_module),
  1375. hidden_states,
  1376. cu_seqlens_now,
  1377. rotary_pos_emb,
  1378. # use_reentrant=self.config.recompute_use_reentrant,
  1379. )
  1380. return hidden_states
  1381. def forward(
  1382. self, hidden_states: paddle.Tensor, grid_thw: paddle.Tensor
  1383. ) -> paddle.Tensor:
  1384. # breakpoint()
  1385. hidden_states = self.patch_embed(hidden_states)
  1386. rotary_pos_emb = self.rot_pos_emb(grid_thw)
  1387. cu_seqlens = paddle.repeat_interleave(
  1388. grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
  1389. ).cumsum(axis=0, dtype="int32")
  1390. cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
  1391. for idx, blk in enumerate(self.blocks):
  1392. if self.enable_recompute and self.training:
  1393. hidden_states = self.recompute_training_full(
  1394. blk, hidden_states, cu_seqlens, rotary_pos_emb
  1395. )
  1396. else:
  1397. hidden_states = blk(
  1398. hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb
  1399. )
  1400. return self.merger(hidden_states)
class Qwen2VLModel(Qwen2VLPreTrainedModel):
    """Text decoder stack of Qwen2-VL: token embedding, decoder layers, final RMSNorm.

    This class runs no vision encoder; it consumes ``input_ids`` or
    precomputed ``inputs_embeds``.
    """

    def __init__(self, config: Qwen2VLConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.hidden_size = config.hidden_size
        # Recompute defaults to False and is controlled by Trainer
        # Use a vocab-parallel embedding only when tensor parallelism is on and
        # the vocabulary divides evenly across ranks.
        if (
            config.tensor_parallel_degree > 1
            and config.vocab_size % config.tensor_parallel_degree == 0
        ):
            self.embed_tokens = mpu.VocabParallelEmbedding(
                self.vocab_size,
                self.hidden_size,
                weight_attr=paddle.ParamAttr(initializer=nn.initializer.XavierNormal()),
            )
        else:
            self.embed_tokens = nn.Embedding(
                self.vocab_size,
                self.hidden_size,
            )
        # NOTE: padding_idx is intentionally not passed to nn.Embedding here.
        # self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.LayerList(
            [
                Qwen2VLDecoderLayer(config, layer_idx)
                for layer_idx in range(config.num_hidden_layers)
            ]
        )
        self.norm = Qwen2RMSNorm(config, config.hidden_size, eps=config.rms_norm_eps)
        # NOTE(review): attribute name is misspelled ("enamble") but is read by
        # forward() and possibly set by external code; do not rename casually.
        self.enamble_recompute = False

    def get_input_embeddings(self):
        """Return the token embedding layer."""
        return self.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the token embedding layer."""
        self.embed_tokens = value

    @staticmethod
    def _prepare_decoder_attention_mask(
        attention_mask, input_shape, past_key_values_length, dtype
    ):
        """Build an additive float attention mask.

        Args:
            attention_mask: ``None``, or a 2-D, 3-D or 4-D mask (see branches).
            input_shape: ``(batch_size, seq_length)`` of the current input.
            past_key_values_length: Number of cached positions, forwarded to
                ``_make_causal_mask``.
            dtype: Target float dtype of the returned mask.

        Returns:
            A mask of ``dtype`` holding 0.0 where attention is allowed and
            ``paddle.finfo(dtype).min`` where it is blocked.
        """
        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            if len(attention_mask.shape) == 2:
                expanded_attn_mask = _expand_2d_mask(
                    attention_mask, dtype, tgt_length=input_shape[-1]
                )
                # For decoding phase in generation, seq_length = 1, we don't need to add causal mask
                if input_shape[-1] > 1:
                    combined_attention_mask = _make_causal_mask(
                        input_shape,
                        past_key_values_length=past_key_values_length,
                    )
                    expanded_attn_mask = expanded_attn_mask & combined_attention_mask
            # [bsz, seq_len, seq_len] -> [bsz, 1, seq_len, seq_len]
            elif len(attention_mask.shape) == 3:
                expanded_attn_mask = attention_mask.unsqueeze(1).astype("bool")
            # if attention_mask is already 4-D, do nothing
            else:
                expanded_attn_mask = attention_mask
        else:
            # No mask given: fall back to a purely causal mask.
            expanded_attn_mask = _make_causal_mask(
                input_shape,
                past_key_values_length=past_key_values_length,
            )
        # Convert bool attention_mask to float attention mask, which will be added to attention_scores later
        expanded_attn_mask = paddle.where(
            expanded_attn_mask, 0.0, paddle.finfo(dtype).min
        ).astype(dtype)
        return expanded_attn_mask

    @paddle.jit.not_to_static
    def recompute_training_full(
        self,
        layer_module: nn.Layer,
        hidden_states: paddle.Tensor,
        attention_mask: paddle.Tensor,
        position_ids: Optional[paddle.Tensor],
        past_key_value: paddle.Tensor,
        output_attentions: bool,
        use_cache: bool,
        cache_position: Optional[paddle.Tensor] = None,
    ):
        """Run one decoder layer under activation recomputation.

        Arguments are forwarded positionally to ``layer_module``. The result is the
        layer's output tuple, despite the local name ``hidden_states``.
        """

        def create_custom_forward(module):
            def custom_forward(*inputs):
                return module(*inputs)

            return custom_forward

        hidden_states = recompute(
            create_custom_forward(layer_module),
            hidden_states,
            attention_mask,
            position_ids,
            past_key_value,
            output_attentions,
            use_cache,
            cache_position,
            # NOTE(review): requires config.recompute_use_reentrant to exist.
            use_reentrant=self.config.recompute_use_reentrant,
        )
        return hidden_states

    def forward(
        self,
        input_ids: paddle.Tensor = None,
        attention_mask: Optional[paddle.Tensor] = None,
        position_ids: Optional[paddle.Tensor] = None,
        past_key_values: Optional[List[paddle.Tensor]] = None,
        inputs_embeds: Optional[paddle.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[paddle.Tensor] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Run the decoder stack.

        Exactly one of ``input_ids`` / ``inputs_embeds`` must be given, otherwise
        ``ValueError`` is raised. ``None`` flags fall back to the config values.

        Returns:
            ``BaseModelOutputWithPast`` when ``return_dict`` is truthy. Otherwise,
            a tuple of the non-``None`` values among
            ``[hidden_states, next_cache, all_hidden_states, all_self_attns]``.
        """
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )
        # XOR is true when both or neither are provided.
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
            )
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape
        elif inputs_embeds is not None:
            batch_size, seq_length, _ = inputs_embeds.shape
        else:
            # Unreachable given the XOR check above; kept as a safeguard.
            raise ValueError(
                "You have to specify either decoder_input_ids or decoder_inputs_embeds"
            )
        if past_key_values is None:
            past_key_values = tuple([None] * len(self.layers))
        # NOTE: to make cache can be clear in-time
        # (a list lets the loop below drop each layer's cache after use).
        past_key_values = list(past_key_values)
        seq_length_with_past = seq_length
        cache_length = 0
        if past_key_values[0] is not None:
            # NOTE(review): assumes the cached key's sequence axis is 2 — confirm
            # against the attention module's cache layout.
            cache_length = past_key_values[0][0].shape[2]  # shape[1] in qwen2
            seq_length_with_past += cache_length
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        # embed positions
        if attention_mask is None:
            # [bs, seq_len]
            attention_mask = paddle.ones(
                (batch_size, seq_length_with_past), dtype=paddle.bool
            )
        # With flash attention available, the raw padding mask is passed through;
        # otherwise an additive mask is built here.
        if flash_attn_varlen_func:
            causal_mask = attention_mask
        else:
            causal_mask = self._prepare_decoder_attention_mask(
                attention_mask,
                (batch_size, seq_length),
                cache_length,
                inputs_embeds.dtype,
            )  # [bs, 1, seq_len, seq_len]
        if cache_position is None:
            past_seen_tokens = (
                past_key_values[0][0].shape[2] if past_key_values[0] is not None else 0
            )
            cache_position = paddle.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1]
            )
        if position_ids is None:
            # the hard coded `3` is for temporal, height and width.
            position_ids = cache_position.reshape([1, 1, -1]).expand(
                [3, inputs_embeds.shape[0], -1]
            )
        hidden_states = inputs_embeds
        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = ()
        for idx, (decoder_layer) in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)
            past_key_value = (
                past_key_values[idx] if past_key_values is not None else None
            )
            if self.enamble_recompute and self.training:
                layer_outputs = self.recompute_training_full(
                    decoder_layer,
                    hidden_states,
                    causal_mask,
                    position_ids,
                    past_key_value,
                    output_attentions,
                    use_cache,
                    cache_position,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_value,
                    output_attentions=output_attentions,  # False
                    use_cache=use_cache,  # True
                    cache_position=cache_position,
                )
            # NOTE: clear outdate cache after it has been used for memory saving
            past_key_value = past_key_values[idx] = None
            hidden_states = layer_outputs[0]
            # When use_cache is set, the layer's present cache is its last output.
            next_decoder_cache = (
                next_decoder_cache + (layer_outputs[-1],) if use_cache else None
            )
            if output_attentions:
                all_self_attns += (layer_outputs[1],)
        hidden_states = self.norm(hidden_states)
        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)
        next_cache = next_decoder_cache if use_cache else None
        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
                if v is not None
            )
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
  1630. class Qwen2LMHead(nn.Layer):
  1631. def __init__(self, config, embedding_weights=None, transpose_y=False):
  1632. super(Qwen2LMHead, self).__init__()
  1633. self.config = config
  1634. if (
  1635. config.tensor_parallel_degree > 1
  1636. and config.vocab_size % config.tensor_parallel_degree == 0
  1637. ):
  1638. vocab_size = config.vocab_size // config.tensor_parallel_degree
  1639. else:
  1640. vocab_size = config.vocab_size
  1641. self.transpose_y = transpose_y
  1642. if transpose_y:
  1643. # only for weight from embedding_weights
  1644. if embedding_weights is not None:
  1645. self.weight = embedding_weights
  1646. else:
  1647. self.weight = self.create_parameter(
  1648. shape=[vocab_size, config.hidden_size],
  1649. dtype=paddle.get_default_dtype(),
  1650. )
  1651. else:
  1652. if vocab_size != config.vocab_size:
  1653. with get_rng_state_tracker().rng_state():
  1654. self.weight = self.create_parameter(
  1655. shape=[config.hidden_size, vocab_size],
  1656. dtype=paddle.get_default_dtype(),
  1657. )
  1658. else:
  1659. self.weight = self.create_parameter(
  1660. shape=[config.hidden_size, vocab_size],
  1661. dtype=paddle.get_default_dtype(),
  1662. )
  1663. # Must set distributed attr for Tensor Parallel !
  1664. self.weight.is_distributed = (
  1665. True if (vocab_size != config.vocab_size) else False
  1666. )
  1667. if self.weight.is_distributed:
  1668. # for tie_word_embeddings
  1669. self.weight.split_axis = 0 if self.transpose_y else 1
  1670. def forward(self, hidden_states, tensor_parallel_output=None):
  1671. if tensor_parallel_output is None:
  1672. tensor_parallel_output = self.config.tensor_parallel_output
  1673. # 确保数据类型一致
  1674. if self.weight.dtype != hidden_states.dtype:
  1675. hidden_states = paddle.cast(hidden_states, self.weight.dtype)
  1676. logits = parallel_matmul(
  1677. hidden_states,
  1678. self.weight,
  1679. transpose_y=self.transpose_y,
  1680. tensor_parallel_output=tensor_parallel_output,
  1681. )
  1682. return logits
  1683. class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel):
  1684. _tied_weights_keys = ["lm_head.weight"]
  1685. def __init__(self, config, attn_implementation="flash_attention_2"):
  1686. super().__init__(config)
  1687. config._attn_implementation = attn_implementation
  1688. config.vision_config._attn_implementation = attn_implementation
  1689. self.visual = Qwen2VisionTransformerPretrainedModel._from_config(
  1690. config.vision_config
  1691. )
  1692. self.model = Qwen2VLModel(config)
  1693. self.vocab_size = config.vocab_size
  1694. if config.tie_word_embeddings:
  1695. self.lm_head = Qwen2LMHead(
  1696. config,
  1697. embedding_weights=self.model.embed_tokens.weight,
  1698. transpose_y=True,
  1699. )
  1700. self.tie_weights()
  1701. else:
  1702. self.lm_head = Qwen2LMHead(config)
  1703. self.padding_side = "left" # set it to left by default, user can use setter to change padding_sides
  1704. def get_input_embeddings(self):
  1705. return self.model.embed_tokens
  1706. def set_input_embeddings(self, value):
  1707. self.model.embed_tokens = value
  1708. def get_output_embeddings(self):
  1709. return self.lm_head
  1710. def set_output_embeddings(self, new_embeddings):
  1711. self.lm_head = new_embeddings
  1712. def set_decoder(self, decoder):
  1713. self.model = decoder
  1714. def get_decoder(self):
  1715. return self.model
  1716. @classmethod
  1717. def _get_tensor_parallel_mappings(cls, config: Qwen2VLConfig, is_split=True):
  1718. logging.info("Qwen2 inference model _get_tensor_parallel_mappings")
  1719. from paddlenlp.transformers.conversion_utils import split_or_merge_func
  1720. fn = split_or_merge_func(
  1721. is_split=is_split,
  1722. tensor_parallel_degree=config.tensor_parallel_degree,
  1723. tensor_parallel_rank=config.tensor_parallel_rank,
  1724. num_attention_heads=config.num_attention_heads,
  1725. )
  1726. def get_tensor_parallel_split_mappings(num_layers):
  1727. final_actions = {}
  1728. base_actions = {
  1729. "lm_head.weight": partial(fn, is_column=True),
  1730. # Row Linear
  1731. "embed_tokens.weight": partial(fn, is_column=False),
  1732. "layers.0.self_attn.o_proj.weight": partial(fn, is_column=False),
  1733. "layers.0.mlp.down_proj.weight": partial(fn, is_column=False),
  1734. }
  1735. base_actions["layers.0.self_attn.q_proj.weight"] = partial(
  1736. fn, is_column=True
  1737. )
  1738. base_actions["layers.0.self_attn.q_proj.bias"] = partial(fn, is_column=True)
  1739. # if we have enough num_key_value_heads to split, then split it.
  1740. if config.num_key_value_heads % config.tensor_parallel_degree == 0:
  1741. base_actions["layers.0.self_attn.k_proj.weight"] = partial(
  1742. fn, is_column=True
  1743. )
  1744. base_actions["layers.0.self_attn.v_proj.weight"] = partial(
  1745. fn, is_column=True
  1746. )
  1747. base_actions["layers.0.self_attn.k_proj.bias"] = partial(
  1748. fn, is_column=True
  1749. )
  1750. base_actions["layers.0.self_attn.v_proj.bias"] = partial(
  1751. fn, is_column=True
  1752. )
  1753. if config.fuse_attention_ffn:
  1754. base_actions["layers.0.mlp.gate_up_fused_proj.weight"] = partial(
  1755. fn, is_column=True, is_naive_2fuse=True
  1756. )
  1757. else:
  1758. base_actions["layers.0.mlp.gate_proj.weight"] = partial(
  1759. fn, is_column=True
  1760. )
  1761. base_actions["layers.0.mlp.up_proj.weight"] = partial(
  1762. fn, is_column=True
  1763. )
  1764. for key, action in base_actions.items():
  1765. if "layers.0." in key:
  1766. for i in range(num_layers):
  1767. final_actions[key.replace("layers.0.", f"layers.{i}.")] = action
  1768. final_actions[key] = action
  1769. return final_actions
  1770. mappings = get_tensor_parallel_split_mappings(config.num_hidden_layers)
  1771. return mappings
  1772. @staticmethod
  1773. def get_rope_index(
  1774. spatial_merge_size,
  1775. image_token_id,
  1776. video_token_id,
  1777. vision_start_token_id,
  1778. input_ids: paddle.Tensor,
  1779. image_grid_thw: Optional[paddle.Tensor] = None,
  1780. video_grid_thw: Optional[paddle.Tensor] = None,
  1781. attention_mask: Optional[paddle.Tensor] = None,
  1782. ) -> Tuple[paddle.Tensor, paddle.Tensor]:
  1783. """
  1784. Calculate the 3D rope index based on image and video's temporal, height and width in LLM.
  1785. Explanation:
  1786. Each embedding sequence contains vision embedding and text embedding or just contains text embedding.
  1787. For pure text embedding sequence, the rotary position embedding has no difference with modern LLMs.
  1788. Examples:
  1789. input_ids: [T T T T T], here T is for text.
  1790. temporal position_ids: [0, 1, 2, 3, 4]
  1791. height position_ids: [0, 1, 2, 3, 4]
  1792. width position_ids: [0, 1, 2, 3, 4]
  1793. For vision and text embedding sequence, we calculate 3D rotary position embedding for vision part
  1794. and 1D rotary position embedding for text part.
  1795. Examples:
  1796. Assume we have a video input with 3 temporal patches, 2 height patches and 2 width patches.
  1797. input_ids: [V V V V V V V V V V V V T T T T T], here V is for vision.
  1798. vision temporal position_ids: [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2]
  1799. vision height position_ids: [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1]
  1800. vision width position_ids: [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1]
  1801. text temporal position_ids: [3, 4, 5, 6, 7]
  1802. text height position_ids: [3, 4, 5, 6, 7]
  1803. text width position_ids: [3, 4, 5, 6, 7]
  1804. Here we calculate the text start position_ids as the max vision position_ids plus 1.
  1805. Args:
  1806. input_ids (`paddle.Tensor` of shape `(batch_size, sequence_length)`):
  1807. Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
  1808. it.
  1809. image_grid_thw (`paddle.Tensor` of shape `(num_images, 3)`, *optional*):
  1810. The temporal, height and width of feature shape of each image in LLM.
  1811. video_grid_thw (`paddle.Tensor` of shape `(num_videos, 3)`, *optional*):
  1812. The temporal, height and width of feature shape of each video in LLM.
  1813. attention_mask (`paddle.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  1814. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  1815. - 1 for tokens that are **not masked**,
  1816. - 0 for tokens that are **masked**.
  1817. Returns:
  1818. position_ids (`paddle.Tensor` of shape `(3, batch_size, sequence_length)`)
  1819. mrope_position_deltas (`paddle.Tensor` of shape `(batch_size)`)
  1820. """
  1821. mrope_position_deltas = []
  1822. if image_grid_thw is not None or video_grid_thw is not None:
  1823. total_input_ids = input_ids
  1824. position_ids = paddle.ones(
  1825. [3, input_ids.shape[0], input_ids.shape[1]], dtype=input_ids.dtype
  1826. )
  1827. image_index, video_index = 0, 0
  1828. for i, input_ids in enumerate(total_input_ids):
  1829. # TODO: CUDA error in some paddle version
  1830. if attention_mask is not None:
  1831. input_ids = paddle.to_tensor(
  1832. input_ids.cpu()[attention_mask[i].cpu() == 1]
  1833. ) # NOTE 原始写法
  1834. image_nums, video_nums = 0, 0
  1835. vision_start_indices = paddle.nonzero(
  1836. input_ids == vision_start_token_id
  1837. ).squeeze(
  1838. 1
  1839. ) # NOTE 原始写法
  1840. vision_tokens = input_ids[vision_start_indices + 1]
  1841. image_nums = (
  1842. (vision_tokens == image_token_id).sum()
  1843. if vision_tokens.numel() > 0
  1844. else 0
  1845. )
  1846. video_nums = (
  1847. (vision_tokens == video_token_id).sum()
  1848. if vision_tokens.numel() > 0
  1849. else 0
  1850. )
  1851. input_tokens = input_ids.tolist()
  1852. llm_pos_ids_list: list = []
  1853. st = 0
  1854. remain_images, remain_videos = image_nums, video_nums
  1855. for _ in range(image_nums + video_nums):
  1856. if image_token_id in input_tokens and remain_images > 0:
  1857. ed_image = input_tokens.index(image_token_id, st)
  1858. else:
  1859. ed_image = len(input_tokens) + 1
  1860. if video_token_id in input_tokens and remain_videos > 0:
  1861. ed_video = input_tokens.index(video_token_id, st)
  1862. else:
  1863. ed_video = len(input_tokens) + 1
  1864. if ed_image < ed_video:
  1865. t, h, w = (
  1866. image_grid_thw[image_index][0],
  1867. image_grid_thw[image_index][1],
  1868. image_grid_thw[image_index][2],
  1869. )
  1870. image_index += 1
  1871. remain_images -= 1
  1872. ed = ed_image
  1873. else:
  1874. t, h, w = (
  1875. video_grid_thw[video_index][0],
  1876. video_grid_thw[video_index][1],
  1877. video_grid_thw[video_index][2],
  1878. )
  1879. video_index += 1
  1880. remain_videos -= 1
  1881. ed = ed_video
  1882. llm_grid_t, llm_grid_h, llm_grid_w = (
  1883. t.item(),
  1884. h.item() // spatial_merge_size,
  1885. w.item() // spatial_merge_size,
  1886. )
  1887. text_len = ed - st
  1888. st_idx = (
  1889. llm_pos_ids_list[-1].max() + 1
  1890. if len(llm_pos_ids_list) > 0
  1891. else 0
  1892. )
  1893. llm_pos_ids_list.append(
  1894. paddle.arange(text_len).reshape([1, -1]).expand([3, -1])
  1895. + st_idx
  1896. )
  1897. t_index = (
  1898. paddle.arange(llm_grid_t)
  1899. .reshape([-1, 1])
  1900. .expand([-1, llm_grid_h * llm_grid_w])
  1901. .flatten()
  1902. )
  1903. h_index = (
  1904. paddle.arange(llm_grid_h)
  1905. .reshape([1, -1, 1])
  1906. .expand([llm_grid_t, -1, llm_grid_w])
  1907. .flatten()
  1908. )
  1909. w_index = (
  1910. paddle.arange(llm_grid_w)
  1911. .reshape([1, 1, -1])
  1912. .expand([llm_grid_t, llm_grid_h, -1])
  1913. .flatten()
  1914. )
  1915. llm_pos_ids_list.append(
  1916. paddle.stack([t_index, h_index, w_index]) + text_len + st_idx
  1917. )
  1918. st = ed + llm_grid_t * llm_grid_h * llm_grid_w
  1919. if st < len(input_tokens):
  1920. st_idx = (
  1921. llm_pos_ids_list[-1].max() + 1
  1922. if len(llm_pos_ids_list) > 0
  1923. else 0
  1924. )
  1925. text_len = len(input_tokens) - st
  1926. llm_pos_ids_list.append(
  1927. paddle.arange(text_len).reshape([1, -1]).expand([3, -1])
  1928. + st_idx
  1929. )
  1930. llm_positions = paddle.concat(llm_pos_ids_list, axis=1).reshape([3, -1])
  1931. if _IS_NPU:
  1932. bool_indices = (
  1933. (attention_mask[i] == 1)
  1934. .unsqueeze(0)
  1935. .tile([position_ids.shape[0], 1])
  1936. )
  1937. position_ids[:, i] = paddle.index_put(
  1938. position_ids[:, i], [bool_indices], llm_positions.reshape([-1])
  1939. )
  1940. else:
  1941. position_ids[..., i, attention_mask[i] == 1] = llm_positions
  1942. mrope_position_deltas.append(
  1943. llm_positions.max() + 1 - len(total_input_ids[i])
  1944. )
  1945. mrope_position_deltas = paddle.to_tensor(mrope_position_deltas).unsqueeze(1)
  1946. else:
  1947. if attention_mask is not None:
  1948. position_ids = paddle.cast(attention_mask, dtype="int64").cumsum(-1) - 1
  1949. position_ids.masked_fill_(mask=attention_mask == 0, value=1)
  1950. position_ids = position_ids.unsqueeze(0).expand([3, -1, -1])
  1951. max_position_ids = position_ids.max(0, keepdim=False)[0].max(
  1952. -1, keepdim=True
  1953. )[0]
  1954. mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
  1955. else:
  1956. position_ids = (
  1957. paddle.arange(input_ids.shape[1])
  1958. .reshape([1, 1, -1])
  1959. .expand(shape=[3, input_ids.shape[0], -1])
  1960. )
  1961. mrope_position_deltas = paddle.zeros(
  1962. [input_ids.shape[0], 1], dtype=input_ids.dtype
  1963. )
  1964. return position_ids, mrope_position_deltas
  1965. def update_model_kwargs_for_generation(
  1966. self,
  1967. outputs: ModelOutput,
  1968. model_kwargs: Dict[str, Any],
  1969. is_encoder_decoder: bool = False,
  1970. # num_new_tokens: int = 1,
  1971. ) -> Dict[str, Any]:
  1972. model_kwargs = super().update_model_kwargs_for_generation(
  1973. outputs=outputs,
  1974. model_kwargs=model_kwargs,
  1975. is_encoder_decoder=is_encoder_decoder,
  1976. )
  1977. if getattr(outputs, "rope_deltas", None) is not None:
  1978. model_kwargs["rope_deltas"] = outputs.rope_deltas
  1979. return model_kwargs
  1980. def vision_forward(
  1981. self,
  1982. input_ids: paddle.Tensor,
  1983. inputs_embeds: Optional[paddle.Tensor] = None,
  1984. attention_mask: Optional[paddle.Tensor] = None,
  1985. position_ids: Optional[paddle.Tensor] = None,
  1986. pixel_values: Optional[paddle.Tensor] = None,
  1987. pixel_values_videos: Optional[paddle.Tensor] = None,
  1988. image_grid_thw: Optional[paddle.Tensor] = None,
  1989. video_grid_thw: Optional[paddle.Tensor] = None,
  1990. rope_deltas: Optional[paddle.Tensor] = None,
  1991. ):
  1992. if inputs_embeds is None:
  1993. from paddlenlp.experimental.transformers.qwen2.modeling import (
  1994. Qwen2VLForConditionalGenerationBlockInferenceModel,
  1995. )
  1996. assert isinstance(
  1997. self.model, Qwen2VLForConditionalGenerationBlockInferenceModel
  1998. ), "model is not an instance of Qwen2VLForConditionalGenerationBlockInferenceModel"
  1999. inputs_embeds = self.model.qwen2.embed_tokens(input_ids)
  2000. if pixel_values is not None:
  2001. pixel_values = paddle.cast(pixel_values, paddle.bfloat16)
  2002. image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
  2003. image_mask = input_ids == self.config.image_token_id
  2004. inputs_embeds[image_mask] = image_embeds
  2005. if pixel_values_videos is not None:
  2006. pixel_values_videos = paddle.cast(pixel_values_videos, paddle.bfloat16)
  2007. video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
  2008. video_mask = input_ids == self.config.video_token_id
  2009. inputs_embeds[video_mask] = video_embeds
  2010. return inputs_embeds
  2011. def forward(
  2012. self,
  2013. input_ids: paddle.Tensor = None,
  2014. attention_mask: Optional[paddle.Tensor] = None,
  2015. position_ids: Optional[paddle.Tensor] = None,
  2016. past_key_values: Optional[List[paddle.Tensor]] = None,
  2017. inputs_embeds: Optional[paddle.Tensor] = None,
  2018. labels: Optional[paddle.Tensor] = None,
  2019. use_cache: Optional[bool] = None,
  2020. output_attentions: Optional[bool] = None,
  2021. output_hidden_states: Optional[bool] = None,
  2022. return_dict: Optional[bool] = None,
  2023. pixel_values: Optional[paddle.Tensor] = None,
  2024. pixel_values_videos: Optional[paddle.Tensor] = None,
  2025. image_grid_thw: Optional[paddle.Tensor] = None,
  2026. video_grid_thw: Optional[paddle.Tensor] = None,
  2027. rope_deltas: Optional[paddle.Tensor] = None,
  2028. ):
  2029. """
  2030. Args:
  2031. labels (`paddle.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  2032. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  2033. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  2034. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  2035. """
  2036. output_attentions = (
  2037. output_attentions
  2038. if output_attentions is not None
  2039. else self.config.output_attentions
  2040. )
  2041. output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states # fmt:skip
  2042. return_dict = True # return_dict if return_dict is not None else self.config.use_return_dict
  2043. if inputs_embeds is None:
  2044. inputs_embeds = self.model.embed_tokens(input_ids)
  2045. if pixel_values is not None:
  2046. pixel_values = paddle.cast(pixel_values, inputs_embeds.dtype)
  2047. image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
  2048. image_embeds = paddle.cast(image_embeds, inputs_embeds.dtype)
  2049. image_mask = input_ids == self.config.image_token_id
  2050. if self.training:
  2051. inputs_embeds = inputs_embeds.clone()
  2052. inputs_embeds[image_mask] = image_embeds
  2053. if pixel_values_videos is not None:
  2054. pixel_values_videos = paddle.cast(
  2055. pixel_values_videos, inputs_embeds.dtype
  2056. )
  2057. video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
  2058. video_embeds = paddle.cast(video_embeds, inputs_embeds.dtype)
  2059. video_mask = input_ids == self.config.video_token_id
  2060. inputs_embeds[video_mask] = video_embeds
  2061. if attention_mask is not None:
  2062. attention_mask = attention_mask
  2063. outputs = self.model(
  2064. input_ids=None,
  2065. position_ids=position_ids,
  2066. attention_mask=attention_mask,
  2067. past_key_values=past_key_values,
  2068. inputs_embeds=inputs_embeds,
  2069. use_cache=use_cache,
  2070. output_attentions=output_attentions,
  2071. output_hidden_states=output_hidden_states,
  2072. return_dict=return_dict,
  2073. )
  2074. hidden_states = outputs[0]
  2075. tensor_parallel_output = (
  2076. self.config.tensor_parallel_output
  2077. and self.config.tensor_parallel_degree > 1
  2078. )
  2079. logits = self.lm_head(
  2080. hidden_states, tensor_parallel_output=tensor_parallel_output
  2081. )
  2082. logits = paddle.cast(logits, "float32")
  2083. loss = None
  2084. if labels is not None:
  2085. # Shift so that tokens < n predict n
  2086. shift_logits = logits[..., :-1, :]
  2087. shift_labels = labels[..., 1:]
  2088. # Flatten the tokens
  2089. shift_logits = shift_logits.reshape([-1, self.config.vocab_size])
  2090. shift_labels = shift_labels.reshape([-1])
  2091. if _IS_NPU:
  2092. tmp = F.log_softmax(shift_logits, axis=1)
  2093. loss = F.nll_loss(tmp, shift_labels, reduction="sum")
  2094. else:
  2095. loss_fct = nn.CrossEntropyLoss(reduction="sum")
  2096. loss = loss_fct(shift_logits, shift_labels)
  2097. label_sum = paddle.sum(shift_labels != -100).cast("float32")
  2098. loss = loss / label_sum
  2099. if not return_dict:
  2100. output = (logits,) + tuple(outputs[1:])
  2101. return (loss,) + output if loss is not None else output
  2102. return Qwen2VLCausalLMOutputWithPast(
  2103. loss=loss,
  2104. logits=logits,
  2105. past_key_values=outputs.past_key_values,
  2106. hidden_states=outputs.hidden_states,
  2107. attentions=outputs.attentions,
  2108. rope_deltas=rope_deltas,
  2109. )
  2110. def prepare_inputs_for_generation(
  2111. self,
  2112. input_ids,
  2113. past_key_values=None,
  2114. attention_mask=None,
  2115. inputs_embeds=None,
  2116. cache_position=None,
  2117. position_ids=None,
  2118. use_cache=True,
  2119. pixel_values=None,
  2120. pixel_values_videos=None,
  2121. image_grid_thw=None,
  2122. video_grid_thw=None,
  2123. **kwargs,
  2124. ):
  2125. batch_size, seq_length = input_ids.shape
  2126. if past_key_values is None:
  2127. cache_position = paddle.arange(input_ids.shape[1])
  2128. else:
  2129. cache_position = paddle.to_tensor([seq_length - 1])
  2130. if past_key_values is not None:
  2131. input_ids = input_ids[:, -1].unsqueeze(-1)
  2132. rope_deltas = kwargs.get("rope_deltas", None)
  2133. if attention_mask is not None and position_ids is None:
  2134. if cache_position is None or (
  2135. cache_position is not None and cache_position[0] == 0
  2136. ):
  2137. position_ids, rope_deltas = self.get_rope_index(
  2138. self.config.vision_config.spatial_merge_size,
  2139. self.config.image_token_id,
  2140. self.config.video_token_id,
  2141. self.config.vision_start_token_id,
  2142. input_ids,
  2143. image_grid_thw,
  2144. video_grid_thw,
  2145. attention_mask,
  2146. )
  2147. else:
  2148. batch_size, seq_length = input_ids.shape
  2149. delta = (
  2150. cache_position[0] + rope_deltas
  2151. if cache_position is not None and rope_deltas is not None
  2152. else 0
  2153. )
  2154. position_ids = paddle.arange(seq_length)
  2155. position_ids = position_ids.reshape([1, -1]).expand([batch_size, -1])
  2156. position_ids = position_ids + delta
  2157. position_ids = position_ids.unsqueeze(axis=0).expand([3, -1, -1])
  2158. if cache_position[0] != 0:
  2159. pixel_values = None
  2160. pixel_values_videos = None
  2161. # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
  2162. if inputs_embeds is not None and cache_position[0] == 0:
  2163. model_inputs = {"inputs_embeds": inputs_embeds}
  2164. else:
  2165. model_inputs = {"input_ids": input_ids}
  2166. model_inputs.update(
  2167. {
  2168. "position_ids": position_ids, # [3, 1, 3602]
  2169. "past_key_values": past_key_values, # DynamicCache()
  2170. "use_cache": use_cache, # 1
  2171. "attention_mask": attention_mask, # [1, 3602]
  2172. "pixel_values": pixel_values, # [14308, 1176]
  2173. "pixel_values_videos": pixel_values_videos,
  2174. "image_grid_thw": image_grid_thw, # [[ 1, 98, 146]]
  2175. "video_grid_thw": video_grid_thw,
  2176. "rope_deltas": rope_deltas, # [[-3504]]
  2177. }
  2178. )
  2179. return model_inputs
  2180. def gme_qwen2_vl_forward(
  2181. self,
  2182. input_ids: paddle.Tensor = None,
  2183. attention_mask: Optional[paddle.Tensor] = None,
  2184. position_ids: Optional[paddle.Tensor] = None,
  2185. past_key_values: Optional[List[paddle.Tensor]] = None,
  2186. inputs_embeds: Optional[paddle.Tensor] = None,
  2187. labels: Optional[paddle.Tensor] = None,
  2188. use_cache: Optional[bool] = None,
  2189. output_attentions: Optional[bool] = None,
  2190. output_hidden_states: Optional[bool] = None,
  2191. return_dict: Optional[bool] = None,
  2192. pixel_values: Optional[paddle.Tensor] = None,
  2193. pixel_values_videos: Optional[paddle.Tensor] = None,
  2194. image_grid_thw: Optional[paddle.Tensor] = None,
  2195. video_grid_thw: Optional[paddle.Tensor] = None,
  2196. rope_deltas: Optional[paddle.Tensor] = None,
  2197. ):
  2198. output_attentions = (
  2199. output_attentions
  2200. if output_attentions is not None
  2201. else self.config.output_attentions
  2202. )
  2203. output_hidden_states = (
  2204. output_hidden_states
  2205. if output_hidden_states is not None
  2206. else self.config.output_hidden_states
  2207. )
  2208. return_dict = True # return_dict if return_dict is not None else self.config.use_return_dict
  2209. if inputs_embeds is None:
  2210. inputs_embeds = self.model.embed_tokens(input_ids)
  2211. if pixel_values is not None:
  2212. # 确保 pixel_values 和 inputs_embeds 使用相同的数据类型
  2213. pixel_values = paddle.cast(pixel_values, inputs_embeds.dtype)
  2214. image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
  2215. # 确保 image_embeds 和 inputs_embeds 使用相同的数据类型
  2216. image_embeds = paddle.cast(image_embeds, inputs_embeds.dtype)
  2217. image_mask = input_ids == self.config.image_token_id
  2218. if self.training:
  2219. inputs_embeds = inputs_embeds.clone()
  2220. inputs_embeds[image_mask] = image_embeds
  2221. if pixel_values_videos is not None:
  2222. # 确保 pixel_values_videos 和 inputs_embeds 使用相同的数据类型
  2223. pixel_values_videos = paddle.cast(
  2224. pixel_values_videos, inputs_embeds.dtype
  2225. )
  2226. video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
  2227. # 确保 video_embeds 和 inputs_embeds 使用相同的数据类型
  2228. video_embeds = paddle.cast(video_embeds, inputs_embeds.dtype)
  2229. video_mask = input_ids == self.config.video_token_id
  2230. inputs_embeds[video_mask] = video_embeds
  2231. if attention_mask is not None:
  2232. attention_mask = attention_mask
  2233. outputs = self.model(
  2234. input_ids=None,
  2235. position_ids=position_ids,
  2236. attention_mask=attention_mask,
  2237. past_key_values=past_key_values,
  2238. inputs_embeds=inputs_embeds,
  2239. use_cache=use_cache,
  2240. output_attentions=output_attentions,
  2241. output_hidden_states=output_hidden_states,
  2242. return_dict=return_dict,
  2243. )
  2244. hidden_states = outputs[0]
  2245. # get last hidden state
  2246. last_hidden_state = hidden_states[:, -1, :]
  2247. return last_hidden_state
  2248. class PPDocBeeInference(Qwen2VLForConditionalGeneration):
  2249. def generate(self, inputs, **kwargs):
  2250. max_new_tokens = kwargs.get("max_new_tokens", 2048)
  2251. temperature = kwargs.get("temperature", 0.1)
  2252. top_p = kwargs.get("top_p", 0.001)
  2253. top_k = kwargs.get("top_k", 1)
  2254. with paddle.no_grad():
  2255. generated_ids = super().generate(
  2256. **inputs,
  2257. max_new_tokens=max_new_tokens,
  2258. temperature=temperature,
  2259. top_p=top_p,
  2260. top_k=top_k,
  2261. )
  2262. return generated_ids