BK-Lee committed
Commit eacf0bd
1 Parent(s): 758b722
README.md CHANGED
@@ -1,8 +1,8 @@
 ---
 title: TroL
-emoji: 🦀
-colorFrom: blue
-colorTo: blue
+emoji: 👽
+colorFrom: yellow
+colorTo: green
 sdk: gradio
 sdk_version: 4.36.1
 app_file: app.py
app.py ADDED
@@ -0,0 +1,147 @@
+# A100 Zero GPU
+import spaces
+
+# TroL Package
+import torch
+from PIL import Image
+from utils.utils import *
+import torch.nn.functional as F
+from trol.load_trol import load_trol
+from torchvision.transforms.functional import pil_to_tensor
+
+# Gradio Package
+import time
+import gradio as gr
+from threading import Thread
+from accelerate import Accelerator
+from transformers import TextIteratorStreamer
+from torchvision.transforms.functional import pil_to_tensor
+
+# flash attention
+import subprocess
+subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
+
+# accel
+accel = Accelerator()
+
+# model selection
+link = "TroL-7B"  # [Select One] 'TroL-1.8B' | 'TroL-3.8B' | 'TroL-7B'
+
+# User prompt
+prompt_type = "with_image"  # Select one option: "text_only", "with_image"
+img_path = 'figures/demo.png'
+question = "What is the troll doing? Describe the details in the image and imagine what might be happening."
+
+# loading model
+model, tokenizer = load_trol(link=link)
+
+# cpu -> gpu
+for param in model.parameters():
+    if not param.is_cuda:
+        param.data = param.to('cuda:0')
+
+def threading_function(inputs, image_token_number, streamer, device, temperature, new_max_token, top_p):
+
+    # propagation
+    _inputs = model.eval_process(inputs=inputs,
+                                 data='demo',
+                                 tokenizer=tokenizer,
+                                 device=device,
+                                 img_token_number=image_token_number)
+    generation_kwargs = _inputs
+    generation_kwargs.update({'streamer': streamer})
+    generation_kwargs.update({'do_sample': True})
+    generation_kwargs.update({'max_new_tokens': new_max_token})
+    generation_kwargs.update({'top_p': top_p})
+    generation_kwargs.update({'temperature': temperature})
+    generation_kwargs.update({'use_cache': True})
+    return model.generate(**generation_kwargs)
+
+@spaces.GPU
+def bot_streaming(message, history, temperature, new_max_token, top_p):
+
+    try:
+        # prompt type -> input prompt
+        image_token_number = None
+        if len(message['files']) != 0:
+            # Image Load
+            image = pil_to_tensor(Image.open(message['files'][0]).convert("RGB"))
+            if "3.8B" not in link:
+                image_token_number = 1225
+                image = F.interpolate(image.unsqueeze(0), size=(490, 490), mode='bicubic').squeeze(0)
+            inputs = [{'image': image, 'question': message['text']}]
+
+        else:
+            inputs = [{'question': message['text']}]
+
+        # Text Generation
+        with torch.inference_mode():
+            # kwargs
+            streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)
+
+            # Threading generation
+            thread = Thread(target=threading_function, kwargs=dict(inputs=inputs,
+                                                                   image_token_number=image_token_number,
+                                                                   streamer=streamer,
+                                                                   device=accel.device,
+                                                                   temperature=temperature,
+                                                                   new_max_token=new_max_token,
+                                                                   top_p=top_p))
+            thread.start()
+
+            # generated text
+            generated_text = ""
+            for new_text in streamer:
+                generated_text += new_text
+            generated_text
+
+            # Text decoding
+            response = output_filtering(generated_text, model)
+
+    except:
+        response = "There may be an unsupported format (e.g. pdf, video, sound); only a single image is supported in this version."
+
+    # private log print
+    text = message['text']
+    files = message['files']
+    print(f'Text: {text}')
+    print(f'MM Files: {files}')
+
+
+    buffer = ""
+    for character in response:
+        buffer += character
+        time.sleep(0.015)
+        yield buffer
+
+demo = gr.ChatInterface(fn=bot_streaming,
+                        additional_inputs=[gr.Slider(0, 1, 0.9, label="temperature"), gr.Slider(1, 1024, 128, label="new_max_token"), gr.Slider(0, 1, 0.95, label="top_p")],
+                        additional_inputs_accordion="Generation Hyperparameters",
+                        theme=gr.themes.Soft(),
+                        title="☄️Meteor",
+                        description="Meteor is an efficient 7B-sized Large Language and Vision Model built with the help of traversal of rationale.\n"
+                                    "Its inference speed highly depends on being assigned a non-scheduled GPU. (Therefore, once all GPUs are busy, inference may take a very long time.)",
+                        stop_btn="Stop Generation", multimodal=True)
+demo.launch()
+
+
+
+
+
+
+
+
+
+
+
+
+
+# Generate
+with torch.inference_mode():
+    _inputs = model.eval_process(inputs=inputs,
+                                 data='demo',
+                                 tokenizer=tokenizer,
+                                 device='cuda:0',
+                                 img_token_number=image_token_number)
+    generate_ids = model.generate(**_inputs, max_new_tokens=256, use_cache=True)
+    response = output_filtering(tokenizer.batch_decode(generate_ids, skip_special_tokens=False)[0], model)
+    print(response)
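The streaming path in `bot_streaming` above is the standard `TextIteratorStreamer` + background-thread pattern from `transformers`: `generate()` blocks, so it runs in a worker thread while the caller consumes decoded chunks as they are produced. A minimal, self-contained sketch of that pattern with a generic Hugging Face causal LM (the model name and prompt are placeholders, not part of this repo):

```python
# Sketch of the streamer + worker-thread pattern used in app.py,
# shown with a generic causal LM instead of TroL's load_trol() loader.
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_id = "gpt2"  # placeholder model
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

inputs = tokenizer("A troll walks into a forest and", return_tensors="pt")
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

# generate() runs in the worker thread; the main thread streams the output.
thread = Thread(target=model.generate,
                kwargs=dict(**inputs, streamer=streamer, max_new_tokens=32, do_sample=True))
thread.start()

generated_text = ""
for new_text in streamer:          # yields decoded chunks as generation proceeds
    generated_text += new_text
    print(new_text, end="", flush=True)
thread.join()
```

app.py does the same thing, except the generation kwargs come from `model.eval_process(...)` and the accumulated text is re-yielded character by character to the Gradio chat.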
config.py ADDED
@@ -0,0 +1,4 @@
+# Checkpoints & Dataset root
+TROL_1_8B="BK-Lee/TroL-1.8B"
+TROL_3_8B="BK-Lee/TroL-3.8B"
+TROL_7B="BK-Lee/TroL-7B"
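config.py only pins the Hub repo IDs of the three released checkpoints; how `load_trol(link=...)` consumes them is not shown in this commit. The sketch below is a hypothetical helper (the function name and mapping are assumptions) illustrating how the short `link` string used in app.py could resolve to one of these IDs:

```python
# Hypothetical helper (not part of the repo): map the `link` string from app.py
# to the Hub checkpoint IDs defined in config.py.
from config import TROL_1_8B, TROL_3_8B, TROL_7B

def resolve_checkpoint(link: str) -> str:
    mapping = {"TroL-1.8B": TROL_1_8B, "TroL-3.8B": TROL_3_8B, "TroL-7B": TROL_7B}
    return mapping[link]

print(resolve_checkpoint("TroL-7B"))  # BK-Lee/TroL-7B
```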
requirements.txt ADDED
@@ -0,0 +1,20 @@
+transformers
+bitsandbytes
+accelerate
+peft
+pandas
+pyarrow
+jsonlines
+wandb
+einops
+timm
+einops_exts
+sentencepiece
+shortuuid
+seaborn
+matplotlib
+scikit-learn
+word2number
+Rouge
+gradio
+spaces
trol/arch_internlm2/build_module.py ADDED
@@ -0,0 +1,208 @@
+import re
+
+import torch
+import torch.nn as nn
+from transformers import CLIPVisionModel
+
+
+def build_vision_tower():
+    vision_tower = 'openai/clip-vit-large-patch14-336'
+    return CLIPVisionTower(vision_tower)
+
+
+def build_vision_projector(hidden_size):
+    projector_type = 'mlp2x_gelu'
+    mm_hidden_size = 1024
+
+    mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
+    if mlp_gelu_match:
+        mlp_depth = int(mlp_gelu_match.group(1))
+        modules = [nn.Linear(mm_hidden_size, hidden_size)]
+        for _ in range(1, mlp_depth):
+            modules.append(nn.GELU())
+            modules.append(nn.Linear(hidden_size, hidden_size))
+        return nn.Sequential(*modules)
+
+    if projector_type == 'identity':
+        return IdentityMap()
+
+    raise ValueError(f'Unknown projector type: {projector_type}')
+
+
+class IdentityMap(nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x, *args, **kwargs):
+        return x
+
+    @property
+    def config(self):
+        return {'mm_projector_type': 'identity'}
+
+
+class CLIPVisionTower(nn.Module):
+
+    def __init__(self, vision_tower):
+        super().__init__()
+
+        self.is_loaded = False
+        self.is_resize_pos = False
+
+        self.vision_tower_name = vision_tower
+        self.select_layer = -1
+        self.select_feature = 'patch'
+        self.load_model()
+        self.resize_pos()
+
+    def load_model(self):
+        self.vision_tower = CLIPVisionModel.from_pretrained(
+            self.vision_tower_name)
+        self.vision_tower.requires_grad_(False)
+
+        self.is_loaded = True
+
+    def resize_pos(self):
+        pos_embed_checkpoint = self.vision_tower.vision_model.embeddings.position_embedding.weight
+        pos_embed_checkpoint = pos_embed_checkpoint.unsqueeze(0)
+        orig_size = 24
+        new_size = 35
+
+        if pos_embed_checkpoint.shape[1] == new_size**2 + 1:
+            self.is_resize_pos = True
+        else:
+            embedding_size = pos_embed_checkpoint.shape[-1]
+            num_extra_tokens = 1
+            new_num = new_size**2 + num_extra_tokens
+            # print('Position interpolate from %dx%d to %dx%d' %
+            #       (orig_size, orig_size, new_size, new_size))
+            extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
+            # only the position tokens are interpolated
+            pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
+            pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size,
+                                            embedding_size).permute(
+                                                0, 3, 1, 2)
+            pos_tokens = torch.nn.functional.interpolate(
+                pos_tokens,
+                size=(new_size, new_size),
+                mode='bicubic',
+                align_corners=False)
+            pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
+            new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
+
+            new_pos_embed = new_pos_embed.squeeze(0)
+
+            self.vision_tower.vision_model.embeddings.position_embedding = torch.nn.Embedding(
+                new_num, 1024)
+            self.vision_tower.vision_model.embeddings.position_embedding.weight = torch.nn.Parameter(
+                new_pos_embed.to(pos_embed_checkpoint.dtype))
+            self.vision_tower.vision_model.embeddings.position_ids = torch.arange(
+                new_num).expand((1, -1))
+
+            self.is_resize_pos = True
+
+    def feature_select(self, image_forward_outs):
+        image_features = image_forward_outs.hidden_states[self.select_layer]
+        if self.select_feature == 'patch':
+            image_features = image_features[:, 1:]
+        elif self.select_feature == 'cls_patch':
+            image_features = image_features
+        else:
+            raise ValueError(
+                f'Unexpected select feature: {self.select_feature}')
+        return image_features
+
+    def forward(self, images):
+        if not self.is_loaded:
+            self.load_model()
+        if type(images) is list:
+            image_features = []
+            for image in images:
+                image_forward_out = self.vision_tower(
+                    image.to(device=self.device,
+                             dtype=self.dtype).unsqueeze(0),
+                    output_hidden_states=True)
+                image_feature = self.feature_select(image_forward_out).to(
+                    image.dtype)
+                image_features.append(image_feature)
+        else:
+            image_forward_outs = self.vision_tower(
+                images.to(device=self.device, dtype=self.dtype),
+                output_hidden_states=True)
+            image_features = self.feature_select(image_forward_outs).to(
+                images.dtype)
+
+        return image_features
+
+    @property
+    def dummy_feature(self):
+        return torch.zeros(
+            1, self.hidden_size, device=self.device, dtype=self.dtype)
+
+    @property
+    def dtype(self):
+        return self.vision_tower.dtype
+
+    @property
+    def device(self):
+        return self.vision_tower.device
+
+    @property
+    def config(self):
+        if self.is_loaded:
+            return self.vision_tower.config
+        else:
+            return self.cfg_only
+
+    @property
+    def hidden_size(self):
+        return self.config.hidden_size
+
+    @property
+    def num_patches(self):
+        return (self.config.image_size // self.config.patch_size)**2
+
+
+class LoRA(nn.Module):
+
+    def __init__(self,
+                 in_features: int,
+                 out_features: int,
+                 bias: bool = True,
+                 device=None,
+                 dtype=None,
+                 lora_r=8,
+                 lora_alpha=16,
+                 lora_dropout=0.05,
+                 lora_len=0,
+                 **kwargs) -> None:
+        super().__init__()
+        self.lora_r = lora_r
+        self.lora_alpha = lora_alpha
+        self.lora_len = lora_len
+        if lora_dropout > 0.:
+            self.lora_dropout = nn.Dropout(p=lora_dropout)
+        else:
+            self.lora_dropout = lambda x: x
+        self.lora_scaling = self.lora_alpha / self.lora_r
+
+        self.lora_A = nn.Linear(
+            in_features, self.lora_r, bias=False, device=device, dtype=dtype)
+        self.lora_B = nn.Linear(
+            self.lora_r, out_features, bias=False, device=device, dtype=dtype)
+        self.ffn = nn.Linear(in_features, out_features, bias=bias, device=device, dtype=dtype)
+
+    def forward(self, x, im_mask=None):
+        res = self.ffn(x)
+        if im_mask is not None:
+            if torch.sum(im_mask) > 0:
+                part_x = x[im_mask]
+                res[im_mask] += self.lora_B(
+                    self.lora_A(
+                        self.lora_dropout(part_x))) * self.lora_scaling
+            else:
+                part_x = x[:, :1]
+                res[:, :1] += self.lora_B(
+                    self.lora_A(self.lora_dropout(part_x))) * 0
+        return res
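Putting the pieces of build_module.py together: `build_vision_tower()` returns a CLIP ViT-L/14-336 encoder whose position embeddings are interpolated from a 24x24 to a 35x35 patch grid (1225 patch tokens at 490x490, matching `image_token_number` in app.py), `build_vision_projector(hidden_size)` maps the 1024-dim CLIP features into the LM hidden size with a 2-layer GELU MLP, and `LoRA` wraps a linear layer with a low-rank side path applied only at positions flagged by `im_mask`. A minimal sketch with dummy tensors (the 4096 hidden size is an assumption matching the InternLM2-7B default):

```python
# Sketch (assumed shapes): project dummy CLIP patch features into the LM hidden
# size, then run them through a LoRA-wrapped linear layer.
import torch
from trol.arch_internlm2.build_module import build_vision_projector, LoRA

hidden_size = 4096                       # assumed LM hidden size (InternLM2-7B default)
patches = torch.randn(1, 1225, 1024)     # dummy CLIP features: 35x35 patches, 1024-dim

projector = build_vision_projector(hidden_size)   # mlp2x_gelu: 1024 -> 4096 -> 4096
vis_tokens = projector(patches)                   # (1, 1225, 4096)

lora_ffn = LoRA(hidden_size, hidden_size, bias=False, lora_r=256, lora_alpha=256)
im_mask = torch.ones(1, 1225, dtype=torch.bool)   # marks image-token positions
out = lora_ffn(vis_tokens, im_mask=im_mask)       # base linear + low-rank update on masked tokens
print(out.shape)                                  # torch.Size([1, 1225, 4096])
```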
trol/arch_internlm2/configuration_internlm2.py ADDED
@@ -0,0 +1,166 @@
1
+ # Copyright (c) InternLM. All rights reserved.
2
+ #
3
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
4
+ # and OPT implementations in this library. It has been modified from its
5
+ # original forms to accommodate minor architectural differences compared
6
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
7
+ #
8
+ # Licensed under the Apache License, Version 2.0 (the "License");
9
+ # you may not use this file except in compliance with the License.
10
+ # You may obtain a copy of the License at
11
+ #
12
+ # http://www.apache.org/licenses/LICENSE-2.0
13
+ #
14
+ # Unless required by applicable law or agreed to in writing, software
15
+ # distributed under the License is distributed on an "AS IS" BASIS,
16
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17
+ # See the License for the specific language governing permissions and
18
+ # limitations under the License.
19
+ """InternLM model configuration."""
20
+
21
+ from transformers.configuration_utils import PretrainedConfig
22
+ from transformers.utils import logging
23
+
24
+ logger = logging.get_logger(__name__)
25
+
26
+ INTERNLM_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
27
+
28
+
29
+ class InternLM2Config(PretrainedConfig):
30
+ r"""
31
+ This is the configuration class to store the configuration of a [`InternLMModel`]. It is used to instantiate
32
+ an InternLM model according to the specified arguments, defining the model architecture. Instantiating a
33
+ configuration with the defaults will yield a similar configuration to that of the InternLM-7B.
34
+
35
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
36
+ documentation from [`PretrainedConfig`] for more information.
37
+
38
+
39
+ Args:
40
+ vocab_size (`int`, *optional*, defaults to 32000):
41
+ Vocabulary size of the InternLM model. Defines the number of different tokens that can be represented by the
42
+ `inputs_ids` passed when calling [`InternLMModel`]
43
+ hidden_size (`int`, *optional*, defaults to 4096):
44
+ Dimension of the hidden representations.
45
+ intermediate_size (`int`, *optional*, defaults to 11008):
46
+ Dimension of the MLP representations.
47
+ num_hidden_layers (`int`, *optional*, defaults to 32):
48
+ Number of hidden layers in the Transformer encoder.
49
+ num_attention_heads (`int`, *optional*, defaults to 32):
50
+ Number of attention heads for each attention layer in the Transformer encoder.
51
+ num_key_value_heads (`int`, *optional*):
52
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
53
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
54
+ `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When
55
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
56
+ by meanpooling all the original heads within that group. For more details checkout [this
57
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
58
+ `num_attention_heads`.
59
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
60
+ The non-linear activation function (function or string) in the decoder.
61
+ max_position_embeddings (`int`, *optional*, defaults to 2048):
62
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
63
+ just in case (e.g., 512 or 1024 or 2048).
64
+ initializer_range (`float`, *optional*, defaults to 0.02):
65
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
66
+ rms_norm_eps (`float`, *optional*, defaults to 1e-12):
67
+ The epsilon used by the rms normalization layers.
68
+ use_cache (`bool`, *optional*, defaults to `True`):
69
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
70
+ relevant if `config.is_decoder=True`.
71
+ tie_word_embeddings(`bool`, *optional*, defaults to `False`):
72
+ Whether to tie weight embeddings
73
+ Example:
74
+
75
+ ```python
76
+ >>> from transformers import InternLMModel, InternLMConfig
77
+
78
+ >>> # Initializing a InternLM internlm-7b style configuration
79
+ >>> configuration = InternLMConfig()
80
+
81
+ >>> # Initializing a model from the internlm-7b style configuration
82
+ >>> model = InternLMModel(configuration)
83
+
84
+ >>> # Accessing the model configuration
85
+ >>> configuration = model.config
86
+ ```"""
87
+ model_type = 'internlm'
88
+ _auto_class = 'AutoConfig'
89
+
90
+ def __init__( # pylint: disable=W0102
91
+ self,
92
+ vocab_size=103168,
93
+ hidden_size=4096,
94
+ intermediate_size=11008,
95
+ num_hidden_layers=32,
96
+ num_attention_heads=32,
97
+ num_key_value_heads=None,
98
+ hidden_act='silu',
99
+ max_position_embeddings=2048,
100
+ initializer_range=0.02,
101
+ rms_norm_eps=1e-6,
102
+ use_cache=True,
103
+ pad_token_id=0,
104
+ bos_token_id=1,
105
+ eos_token_id=2,
106
+ tie_word_embeddings=False,
107
+ bias=True,
108
+ rope_theta=10000,
109
+ rope_scaling=None,
110
+ attn_implementation='eager',
111
+ **kwargs,
112
+ ):
113
+ self.vocab_size = vocab_size
114
+ self.max_position_embeddings = max_position_embeddings
115
+ self.hidden_size = hidden_size
116
+ self.intermediate_size = intermediate_size
117
+ self.num_hidden_layers = num_hidden_layers
118
+ self.num_attention_heads = num_attention_heads
119
+ self.bias = bias
120
+
121
+ if num_key_value_heads is None:
122
+ num_key_value_heads = num_attention_heads
123
+ self.num_key_value_heads = num_key_value_heads
124
+
125
+ self.hidden_act = hidden_act
126
+ self.initializer_range = initializer_range
127
+ self.rms_norm_eps = rms_norm_eps
128
+ self.use_cache = use_cache
129
+ self.rope_theta = rope_theta
130
+ self.rope_scaling = rope_scaling
131
+ self._rope_scaling_validation()
132
+
133
+ self.attn_implementation = attn_implementation
134
+ if self.attn_implementation is None:
135
+ self.attn_implementation = 'eager'
136
+ super().__init__(
137
+ pad_token_id=pad_token_id,
138
+ bos_token_id=bos_token_id,
139
+ eos_token_id=eos_token_id,
140
+ tie_word_embeddings=tie_word_embeddings,
141
+ **kwargs,
142
+ )
143
+
144
+ def _rope_scaling_validation(self):
145
+ """Validate the `rope_scaling` configuration."""
146
+ if self.rope_scaling is None:
147
+ return
148
+
149
+ if not isinstance(self.rope_scaling,
150
+ dict) or len(self.rope_scaling) != 2:
151
+ raise ValueError(
152
+ '`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, '
153
+ f'got {self.rope_scaling}')
154
+ rope_scaling_type = self.rope_scaling.get('type', None)
155
+ rope_scaling_factor = self.rope_scaling.get('factor', None)
156
+ if rope_scaling_type is None or rope_scaling_type not in [
157
+ 'linear', 'dynamic'
158
+ ]:
159
+ raise ValueError(
160
+ f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
161
+ )
162
+ if rope_scaling_factor is None or not isinstance(
163
+ rope_scaling_factor, float) or rope_scaling_factor < 1.0:
164
+ raise ValueError(
165
+ f"`rope_scaling`'s factor field must be a float >= 1, got {rope_scaling_factor}"
166
+ )
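The `_rope_scaling_validation` above only accepts a two-field dict with `type` in `{'linear', 'dynamic'}` and a float `factor >= 1.0` (and the attention module in modeling_internlm2.py later supports only `'dynamic'`). A small sketch of a config that passes validation (the tiny layer count is just for illustration):

```python
# Sketch: instantiate the config with dynamic NTK rope scaling.
from trol.arch_internlm2.configuration_internlm2 import InternLM2Config

cfg = InternLM2Config(
    num_hidden_layers=2,                            # tiny config for illustration only
    rope_scaling={"type": "dynamic", "factor": 2.0},
    attn_implementation="eager",
)
print(cfg.rope_theta, cfg.rope_scaling)

# This would raise ValueError (factor must be a float >= 1.0):
# InternLM2Config(rope_scaling={"type": "dynamic", "factor": 0.5})
```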
trol/arch_internlm2/modeling_internlm2.py ADDED
@@ -0,0 +1,1091 @@
1
+ # # Copyright (c) InternLM. All rights reserved.
2
+ #
3
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
4
+ # and OPT implementations in this library. It has been modified from its
5
+ # original forms to accommodate minor architectural differences compared
6
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
7
+ #
8
+ # Licensed under the Apache License, Version 2.0 (the "License");
9
+ # you may not use this file except in compliance with the License.
10
+ # You may obtain a copy of the License at
11
+ #
12
+ # http://www.apache.org/licenses/LICENSE-2.0
13
+ #
14
+ # Unless required by applicable law or agreed to in writing, software
15
+ # distributed under the License is distributed on an "AS IS" BASIS,
16
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17
+ # See the License for the specific language governing permissions and
18
+ # limitations under the License.
19
+ """PyTorch InternLM2 model."""
20
+ import math
21
+ import warnings
22
+ from typing import List, Optional, Tuple, Union
23
+
24
+ import torch
25
+ import torch.nn.functional as F
26
+ import torch.utils.checkpoint
27
+ from einops import rearrange
28
+ from torch import nn
29
+ from transformers.activations import ACT2FN
30
+ from transformers.modeling_outputs import BaseModelOutputWithPast
31
+ from transformers.modeling_utils import PreTrainedModel
32
+ from transformers.utils import (add_start_docstrings,
33
+ add_start_docstrings_to_model_forward, logging)
34
+
35
+ try:
36
+ from transformers.generation.streamers import BaseStreamer
37
+ except: # noqa # pylint: disable=bare-except
38
+ BaseStreamer = None
39
+
40
+ from .build_module import LoRA
41
+ from .configuration_internlm2 import InternLM2Config
42
+
43
+ logger = logging.get_logger(__name__)
44
+
45
+ _CONFIG_FOR_DOC = 'InternLM2Config'
46
+ flash_attn_func, flash_attn_varlen_func = None, None
47
+ pad_input, index_first_axis, unpad_input = None, None, None
48
+ def _import_flash_attn():
49
+ global flash_attn_func, flash_attn_varlen_func
50
+ global pad_input, index_first_axis, unpad_input
51
+ try:
52
+ from flash_attn import flash_attn_func as _flash_attn_func, flash_attn_varlen_func as _flash_attn_varlen_func
53
+ from flash_attn.bert_padding import pad_input as _pad_input, index_first_axis as _index_first_axis, unpad_input as _unpad_input
54
+ flash_attn_func, flash_attn_varlen_func = _flash_attn_func, _flash_attn_varlen_func
55
+ pad_input, index_first_axis, unpad_input = _pad_input, _index_first_axis, _unpad_input
56
+ except ImportError:
57
+ raise ImportError("flash_attn is not installed.")
58
+
59
+ # Copied from transformers.models.llama.modeling_llama._get_unpad_data
60
+ def _get_unpad_data(attention_mask):
61
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
62
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
63
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
64
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
65
+ return (
66
+ indices,
67
+ cu_seqlens,
68
+ max_seqlen_in_batch,
69
+ )
70
+
71
+ # Copied from transformers.models.bart.modeling_bart._make_causal_mask
72
+ def _make_causal_mask(input_ids_shape: torch.Size,
73
+ dtype: torch.dtype,
74
+ device: torch.device,
75
+ past_key_values_length: int = 0):
76
+ """Make causal mask used for bi-directional self-attention."""
77
+ bsz, tgt_len = input_ids_shape
78
+ mask = torch.full((tgt_len, tgt_len),
79
+ torch.tensor(torch.finfo(dtype).min, device=device),
80
+ device=device)
81
+ mask_cond = torch.arange(mask.size(-1), device=device)
82
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
83
+ mask = mask.to(dtype)
84
+
85
+ if past_key_values_length > 0:
86
+ mask = torch.cat([
87
+ torch.zeros(
88
+ tgt_len, past_key_values_length, dtype=dtype, device=device),
89
+ mask
90
+ ],
91
+ dim=-1)
92
+ return mask[None, None, :, :].expand(bsz, 1, tgt_len,
93
+ tgt_len + past_key_values_length)
94
+
95
+
96
+ # Copied from transformers.models.bart.modeling_bart._expand_mask
97
+ def _expand_mask(mask: torch.Tensor,
98
+ dtype: torch.dtype,
99
+ tgt_len: Optional[int] = None):
100
+ """Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len,
101
+ src_seq_len]`."""
102
+ bsz, src_len = mask.size()
103
+ tgt_len = tgt_len if tgt_len is not None else src_len
104
+
105
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len,
106
+ src_len).to(dtype)
107
+
108
+ inverted_mask = 1.0 - expanded_mask
109
+
110
+ return inverted_mask.masked_fill(
111
+ inverted_mask.to(torch.bool),
112
+ torch.finfo(dtype).min)
113
+
114
+
115
+ class InternLM2RMSNorm(nn.Module):
116
+
117
+ def __init__(self, hidden_size, eps=1e-6):
118
+ """InternLM2RMSNorm is equivalent to T5LayerNorm."""
119
+ super().__init__()
120
+ self.weight = nn.Parameter(torch.ones(hidden_size))
121
+ self.variance_epsilon = eps
122
+
123
+ def forward(self, hidden_states):
124
+ input_dtype = hidden_states.dtype
125
+ hidden_states = hidden_states.to(torch.float32)
126
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
127
+ hidden_states = hidden_states * torch.rsqrt(variance +
128
+ self.variance_epsilon)
129
+ return self.weight * hidden_states.to(input_dtype)
130
+
131
+
132
+ class InternLM2RotaryEmbedding(nn.Module):
133
+
134
+ def __init__(self,
135
+ dim,
136
+ max_position_embeddings=2048,
137
+ base=10000,
138
+ device=None):
139
+ super().__init__()
140
+
141
+ self.dim = dim
142
+ self.max_position_embeddings = max_position_embeddings
143
+ self.base = base
144
+ inv_freq = 1.0 / (
145
+ self.base
146
+ **(torch.arange(0, self.dim, 2).float().to(device) / self.dim))
147
+ self.register_buffer('inv_freq', inv_freq, persistent=False)
148
+
149
+ # Build here to make `torch.jit.trace` work.
150
+ self._set_cos_sin_cache(
151
+ seq_len=max_position_embeddings,
152
+ device=self.inv_freq.device,
153
+ dtype=torch.get_default_dtype())
154
+
155
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
156
+ self.max_seq_len_cached = seq_len
157
+ t = torch.arange(
158
+ self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
159
+
160
+ freqs = torch.einsum('i,j->ij', t, self.inv_freq)
161
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
162
+ emb = torch.cat((freqs, freqs), dim=-1)
163
+ self.register_buffer(
164
+ 'cos_cached', emb.cos().to(dtype), persistent=False)
165
+ self.register_buffer(
166
+ 'sin_cached', emb.sin().to(dtype), persistent=False)
167
+
168
+ def forward(self, x, seq_len=None):
169
+ # x: [bs, num_attention_heads, seq_len, head_size]
170
+ if seq_len > self.max_seq_len_cached:
171
+ self._set_cos_sin_cache(
172
+ seq_len=seq_len, device=x.device, dtype=x.dtype)
173
+
174
+ return (
175
+ self.cos_cached[:seq_len].to(dtype=x.dtype),
176
+ self.sin_cached[:seq_len].to(dtype=x.dtype),
177
+ )
178
+
179
+
180
+ class InternLM2LinearScalingRotaryEmbedding(InternLM2RotaryEmbedding):
181
+ """InternLM2RotaryEmbedding extended with linear scaling.
182
+
183
+ Credits to the Reddit user /u/kaiokendev
184
+ """
185
+
186
+ def __init__(self,
187
+ dim,
188
+ max_position_embeddings=2048,
189
+ base=10000,
190
+ device=None,
191
+ scaling_factor=1.0):
192
+ self.scaling_factor = scaling_factor
193
+ super().__init__(dim, max_position_embeddings, base, device)
194
+
195
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
196
+ self.max_seq_len_cached = seq_len
197
+ t = torch.arange(
198
+ self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
199
+ t = t / self.scaling_factor
200
+
201
+ freqs = torch.einsum('i,j->ij', t, self.inv_freq)
202
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
203
+ emb = torch.cat((freqs, freqs), dim=-1)
204
+ self.register_buffer(
205
+ 'cos_cached', emb.cos().to(dtype), persistent=False)
206
+ self.register_buffer(
207
+ 'sin_cached', emb.sin().to(dtype), persistent=False)
208
+
209
+
210
+ class InternLM2DynamicNTKScalingRotaryEmbedding(InternLM2RotaryEmbedding):
211
+ """InternLM2RotaryEmbedding extended with Dynamic NTK scaling.
212
+
213
+ Credits to the Reddit users /u/bloc97 and /u/emozilla.
214
+ """
215
+
216
+ def __init__(self,
217
+ dim,
218
+ max_position_embeddings=2048,
219
+ base=10000,
220
+ device=None,
221
+ scaling_factor=1.0):
222
+ self.scaling_factor = scaling_factor
223
+ super().__init__(dim, max_position_embeddings, base, device)
224
+
225
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
226
+ self.max_seq_len_cached = seq_len
227
+
228
+ if seq_len > self.max_position_embeddings:
229
+ base = self.base * ((self.scaling_factor * seq_len /
230
+ self.max_position_embeddings) -
231
+ (self.scaling_factor - 1))**(
232
+ self.dim / (self.dim - 2))
233
+ inv_freq = 1.0 / (
234
+ base
235
+ **(torch.arange(0, self.dim, 2).float().to(device) / self.dim))
236
+ self.register_buffer('inv_freq', inv_freq, persistent=False)
237
+
238
+ t = torch.arange(
239
+ self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
240
+
241
+ freqs = torch.einsum('i,j->ij', t, self.inv_freq)
242
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
243
+ emb = torch.cat((freqs, freqs), dim=-1)
244
+ self.register_buffer(
245
+ 'cos_cached', emb.cos().to(dtype), persistent=False)
246
+ self.register_buffer(
247
+ 'sin_cached', emb.sin().to(dtype), persistent=False)
248
+
249
+
250
+ def rotate_half(x):
251
+ """Rotates half the hidden dims of the input."""
252
+ x1 = x[..., :x.shape[-1] // 2]
253
+ x2 = x[..., x.shape[-1] // 2:]
254
+ return torch.cat((-x2, x1), dim=-1)
255
+
256
+ # Copied from transformers.model.llama.modeling_llama.apply_rotary_pos_emb
257
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
258
+ """Applies Rotary Position Embedding to the query and key tensors."""
259
+ cos = cos[position_ids].unsqueeze(unsqueeze_dim)
260
+ sin = sin[position_ids].unsqueeze(unsqueeze_dim)
261
+ q_embed = (q * cos) + (rotate_half(q) * sin)
262
+ k_embed = (k * cos) + (rotate_half(k) * sin)
263
+ return q_embed, k_embed
264
+
265
+
266
+ class InternLM2MLP(nn.Module):
267
+
268
+ def __init__(self, config):
269
+ super().__init__()
270
+ self.config = config
271
+ self.hidden_size = config.hidden_size
272
+ self.intermediate_size = config.intermediate_size
273
+
274
+ self.w1 = LoRA(
275
+ self.hidden_size,
276
+ self.intermediate_size,
277
+ bias=False,
278
+ lora_r=256,
279
+ lora_alpha=256,
280
+ lora_len=576)
281
+ self.w3 = LoRA(
282
+ self.hidden_size,
283
+ self.intermediate_size,
284
+ bias=False,
285
+ lora_r=256,
286
+ lora_alpha=256,
287
+ lora_len=576)
288
+ self.w2 = LoRA(
289
+ self.intermediate_size,
290
+ self.hidden_size,
291
+ bias=False,
292
+ lora_r=256,
293
+ lora_alpha=256,
294
+ lora_len=576)
295
+
296
+ self.act_fn = ACT2FN[config.hidden_act]
297
+
298
+ def forward(self, x, im_mask):
299
+ down_proj = self.w2(
300
+ self.act_fn(self.w1(x, im_mask)) * self.w3(x, im_mask), im_mask)
301
+
302
+ return down_proj
303
+
304
+
305
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
306
+ """This is the equivalent of torch.repeat_interleave(x, dim=1,
307
+ repeats=n_rep).
308
+
309
+ The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to
310
+ (batch, num_attention_heads, seqlen, head_dim)
311
+ """
312
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
313
+ if n_rep == 1:
314
+ return hidden_states
315
+ hidden_states = hidden_states[:, :,
316
+ None, :, :].expand(batch,
317
+ num_key_value_heads,
318
+ n_rep, slen, head_dim)
319
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen,
320
+ head_dim)
321
+
322
+
323
+ class InternLM2Attention(nn.Module):
324
+ """Multi-headed attention from 'Attention Is All You Need' paper."""
325
+
326
+ def __init__(self, config: InternLM2Config):
327
+ super().__init__()
328
+ self.config = config
329
+ self.hidden_size = config.hidden_size
330
+ self.num_heads = config.num_attention_heads
331
+ self.head_dim = self.hidden_size // self.num_heads
332
+ self.num_key_value_heads = config.num_key_value_heads
333
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
334
+ self.max_position_embeddings = config.max_position_embeddings
335
+ self.is_causal = True
336
+
337
+ if (self.head_dim * self.num_heads) != self.hidden_size:
338
+ raise ValueError(
339
+ f'hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}'
340
+ f' and `num_heads`: {self.num_heads}).')
341
+
342
+ self.wqkv = LoRA(
343
+ self.hidden_size,
344
+ (self.num_heads + 2 * self.num_key_value_heads) * self.head_dim,
345
+ bias=config.bias,
346
+ lora_r=256,
347
+ lora_alpha=256,
348
+ lora_len=576)
349
+
350
+ self.wo = LoRA(
351
+ self.num_heads * self.head_dim,
352
+ self.hidden_size,
353
+ bias=config.bias,
354
+ lora_r=256,
355
+ lora_alpha=256,
356
+ lora_len=576)
357
+ self._init_rope()
358
+
359
+ def _init_rope(self):
360
+ if self.config.rope_scaling is None:
361
+ self.rotary_emb = InternLM2RotaryEmbedding(
362
+ self.head_dim,
363
+ max_position_embeddings=self.max_position_embeddings,
364
+ base=self.config.rope_theta,
365
+ )
366
+ else:
367
+ scaling_type = self.config.rope_scaling['type']
368
+ scaling_factor = self.config.rope_scaling['factor']
369
+ if scaling_type == 'dynamic':
370
+ self.rotary_emb = InternLM2DynamicNTKScalingRotaryEmbedding(
371
+ self.head_dim,
372
+ max_position_embeddings=self.max_position_embeddings,
373
+ base=self.config.rope_theta,
374
+ scaling_factor=scaling_factor)
375
+ else:
376
+ raise ValueError(
377
+ "Currently we only support rotary embedding's type being 'dynamic'."
378
+ )
379
+ return self.rotary_emb
380
+
381
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
382
+ return tensor.view(bsz, seq_len, self.num_heads,
383
+ self.head_dim).transpose(1, 2).contiguous()
384
+
385
+ def forward(
386
+ self,
387
+ hidden_states: torch.Tensor,
388
+ attention_mask: Optional[torch.Tensor] = None,
389
+ position_ids: Optional[torch.LongTensor] = None,
390
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
391
+ output_attentions: bool = False,
392
+ use_cache: bool = False,
393
+ im_mask: Optional[Tuple[torch.Tensor]] = None,
394
+ **kwargs,
395
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
396
+ Optional[Tuple[torch.Tensor]]]:
397
+ if 'padding_mask' in kwargs:
398
+ warnings.warn(
399
+ 'Passing `padding_mask` is deprecated and will be removed in v4.37. '
400
+ 'Please make sure use `attention_mask` instead.`')
401
+
402
+ bsz, q_len, _ = hidden_states.size()
403
+
404
+ qkv_states = self.wqkv(hidden_states, im_mask)
405
+
406
+ qkv_states = rearrange(
407
+ qkv_states,
408
+ 'b q (h gs d) -> b q h gs d',
409
+ gs=2 + self.num_key_value_groups,
410
+ d=self.head_dim,
411
+ )
412
+
413
+ query_states = qkv_states[..., :self.num_key_value_groups, :]
414
+ query_states = rearrange(query_states, 'b q h gs d -> b q (h gs) d')
415
+ key_states = qkv_states[..., -2, :]
416
+ value_states = qkv_states[..., -1, :]
417
+
418
+ query_states = query_states.transpose(1, 2)
419
+ key_states = key_states.transpose(1, 2)
420
+ value_states = value_states.transpose(1, 2)
421
+
422
+ kv_seq_len = key_states.shape[-2]
423
+ if past_key_value is not None:
424
+ kv_seq_len += past_key_value[0].shape[-2]
425
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
426
+ query_states, key_states = apply_rotary_pos_emb(
427
+ query_states, key_states, cos, sin, position_ids)
428
+
429
+ if past_key_value is not None:
430
+ # reuse k, v, self_attention
431
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
432
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
433
+
434
+ past_key_value = (key_states, value_states) if use_cache else None
435
+
436
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
437
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
438
+
439
+ attn_weights = torch.matmul(query_states, key_states.transpose(
440
+ 2, 3)) / math.sqrt(self.head_dim)
441
+
442
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
443
+ raise ValueError(
444
+ f'Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is'
445
+ f' {attn_weights.size()}')
446
+
447
+ if attention_mask is not None:
448
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
449
+ raise ValueError(
450
+ f'Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}'
451
+ )
452
+ attn_weights = attn_weights + attention_mask
453
+
454
+ # upcast attention to fp32
455
+ attn_weights = nn.functional.softmax(
456
+ attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
457
+ attn_output = torch.matmul(attn_weights, value_states)
458
+
459
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
460
+ raise ValueError(
461
+ f'`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is'
462
+ f' {attn_output.size()}')
463
+
464
+ attn_output = attn_output.transpose(1, 2).contiguous()
465
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
466
+
467
+ attn_output = self.wo(attn_output, im_mask)
468
+
469
+ if not output_attentions:
470
+ attn_weights = None
471
+
472
+ return attn_output, attn_weights, past_key_value
473
+
474
+
475
+ class InternLM2FlashAttention2(InternLM2Attention):
476
+ """InternLM2 flash attention module.
477
+
478
+ This module inherits from `InternLM2Attention` as the weights of the module
479
+ stays untouched. The only required change would be on the forward pass
480
+ where it needs to correctly call the public API of flash attention and deal
481
+ with padding tokens in case the input contains any of them.
482
+ """
483
+
484
+ def forward(
485
+ self,
486
+ hidden_states: torch.Tensor,
487
+ attention_mask: Optional[torch.LongTensor] = None,
488
+ position_ids: Optional[torch.LongTensor] = None,
489
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
490
+ output_attentions: bool = False,
491
+ use_cache: bool = False,
492
+ im_mask: Optional[Tuple[torch.Tensor]] = None,
493
+ **kwargs,
494
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
495
+ Optional[Tuple[torch.Tensor]]]:
496
+ # InternLM2FlashAttention2 attention does not support output_attentions
497
+ if 'padding_mask' in kwargs:
498
+ warnings.warn(
499
+ 'Passing `padding_mask` is deprecated and will be removed in v4.37. '
500
+ 'Please make sure use `attention_mask` instead.`')
501
+
502
+ # overwrite attention_mask with padding_mask
503
+ attention_mask = kwargs.pop('padding_mask')
504
+
505
+ output_attentions = False
506
+
507
+ bsz, q_len, _ = hidden_states.size()
508
+
509
+ qkv_states = self.wqkv(hidden_states, im_mask)
510
+
511
+ qkv_states = rearrange(
512
+ qkv_states,
513
+ 'b q (h gs d) -> b q h gs d',
514
+ gs=2 + self.num_key_value_groups,
515
+ d=self.head_dim,
516
+ q=q_len,
517
+ )
518
+
519
+ query_states = qkv_states[..., :self.num_key_value_groups, :]
520
+ query_states = rearrange(query_states, 'b q h gs d -> b q (h gs) d')
521
+ key_states = qkv_states[..., -2, :]
522
+ value_states = qkv_states[..., -1, :]
523
+ query_states = query_states.transpose(1, 2)
524
+ key_states = key_states.transpose(1, 2)
525
+ value_states = value_states.transpose(1, 2)
526
+
527
+ kv_seq_len = key_states.shape[-2]
528
+ if past_key_value is not None:
529
+ kv_seq_len += past_key_value[0].shape[-2]
530
+
531
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
532
+
533
+ query_states, key_states = apply_rotary_pos_emb(
534
+ query_states, key_states, cos, sin, position_ids)
535
+
536
+ if past_key_value is not None:
537
+ # reuse k, v, self_attention
538
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
539
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
540
+
541
+ past_key_value = (key_states, value_states) if use_cache else None
542
+
543
+ query_states = query_states.transpose(1, 2)
544
+ key_states = key_states.transpose(1, 2)
545
+ value_states = value_states.transpose(1, 2)
546
+
547
+ attn_output = self._flash_attention_forward(
548
+ query_states,
549
+ key_states,
550
+ value_states,
551
+ attention_mask,
552
+ q_len)
553
+
554
+ attn_output = attn_output.reshape(bsz, q_len,
555
+ self.hidden_size).contiguous()
556
+ attn_output = self.wo(attn_output, im_mask)
557
+
558
+ if not output_attentions:
559
+ attn_weights = None
560
+
561
+ return attn_output, attn_weights, past_key_value
562
+
563
+ def _flash_attention_forward(
564
+ self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
565
+ ):
566
+ """
567
+ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
568
+ first unpad the input, then computes the attention scores and pad the final attention scores.
569
+ Args:
570
+ query_states (`torch.Tensor`):
571
+ Input query states to be passed to Flash Attention API
572
+ key_states (`torch.Tensor`):
573
+ Input key states to be passed to Flash Attention API
574
+ value_states (`torch.Tensor`):
575
+ Input value states to be passed to Flash Attention API
576
+ attention_mask (`torch.Tensor`):
577
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
578
+ position of padding tokens and 1 for the position of non-padding tokens.
579
+ dropout (`int`, *optional*):
580
+ Attention dropout
581
+ softmax_scale (`float`, *optional*):
582
+ The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
583
+ """
584
+ # Contains at least one padding token in the sequence
585
+ causal = self.is_causal and query_length != 1
586
+ if attention_mask is not None:
587
+ batch_size = query_states.shape[0]
588
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._unpad_input(
589
+ query_states, key_states, value_states, attention_mask, query_length
590
+ )
591
+
592
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
593
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
594
+
595
+ attn_output_unpad = flash_attn_varlen_func(
596
+ query_states,
597
+ key_states,
598
+ value_states,
599
+ cu_seqlens_q=cu_seqlens_q,
600
+ cu_seqlens_k=cu_seqlens_k,
601
+ max_seqlen_q=max_seqlen_in_batch_q,
602
+ max_seqlen_k=max_seqlen_in_batch_k,
603
+ dropout_p=dropout,
604
+ softmax_scale=softmax_scale,
605
+ causal=causal,
606
+ )
607
+
608
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
609
+ else:
610
+ attn_output = flash_attn_func(
611
+ query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
612
+ )
613
+
614
+ return attn_output
615
+
616
+ def _unpad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
617
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
618
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
619
+
620
+ key_layer = index_first_axis(
621
+ key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
622
+ )
623
+ value_layer = index_first_axis(
624
+ value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
625
+ )
626
+
627
+ if query_length == kv_seq_len:
628
+ query_layer = index_first_axis(
629
+ query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
630
+ )
631
+ cu_seqlens_q = cu_seqlens_k
632
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
633
+ indices_q = indices_k
634
+ elif query_length == 1:
635
+ max_seqlen_in_batch_q = 1
636
+ cu_seqlens_q = torch.arange(
637
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
638
+ ) # There is a memcpy here, that is very bad.
639
+ indices_q = cu_seqlens_q[:-1]
640
+ query_layer = query_layer.squeeze(1)
641
+ else:
642
+ # The -q_len: slice assumes left padding.
643
+ attention_mask = attention_mask[:, -query_length:]
644
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
645
+
646
+ return (
647
+ query_layer,
648
+ key_layer,
649
+ value_layer,
650
+ indices_q.to(torch.int64),
651
+ (cu_seqlens_q, cu_seqlens_k),
652
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
653
+ )
654
+
655
+ class InternLM2DecoderLayer(nn.Module):
656
+
657
+ def __init__(self, config: InternLM2Config):
658
+ super().__init__()
659
+ self.hidden_size = config.hidden_size
660
+ self.attention = (
661
+ InternLM2Attention(config=config)
662
+ if not getattr(config, 'attn_implementation')=="flash_attention_2" else
663
+ InternLM2FlashAttention2(config=config))
664
+ self.feed_forward = InternLM2MLP(config)
665
+ self.attention_norm = InternLM2RMSNorm(
666
+ config.hidden_size, eps=config.rms_norm_eps)
667
+ self.ffn_norm = InternLM2RMSNorm(
668
+ config.hidden_size, eps=config.rms_norm_eps)
669
+
670
+ def forward(
671
+ self,
672
+ hidden_states: torch.Tensor,
673
+ attention_mask: Optional[torch.Tensor] = None,
674
+ position_ids: Optional[torch.LongTensor] = None,
675
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
676
+ output_attentions: Optional[bool] = False,
677
+ use_cache: Optional[bool] = False,
678
+ im_mask: Optional[Tuple[torch.Tensor]] = None,
679
+ **kwargs,
680
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor,
681
+ torch.FloatTensor]]]:
682
+ """
683
+ Args:
684
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
685
+ attention_mask (`torch.FloatTensor`, *optional*):
686
+ attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
687
+ query_sequence_length, key_sequence_length)` if default attention is used.
688
+ output_attentions (`bool`, *optional*):
689
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
690
+ returned tensors for more detail.
691
+ use_cache (`bool`, *optional*):
692
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
693
+ (see `past_key_values`).
694
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
695
+ """
696
+ if 'padding_mask' in kwargs:
697
+ warnings.warn(
698
+ 'Passing `padding_mask` is deprecated and will be removed in v4.37. '
699
+ 'Please make sure use `attention_mask` instead.`')
700
+
701
+ residual = hidden_states
702
+
703
+ hidden_states = self.attention_norm(hidden_states)
704
+
705
+ # Self Attention
706
+ hidden_states, self_attn_weights, present_key_value = self.attention(
707
+ hidden_states=hidden_states,
708
+ attention_mask=attention_mask,
709
+ position_ids=position_ids,
710
+ past_key_value=past_key_value,
711
+ output_attentions=output_attentions,
712
+ use_cache=use_cache,
713
+ im_mask=im_mask,
714
+ **kwargs,
715
+ )
716
+ hidden_states = residual + hidden_states
717
+
718
+ # Fully Connected
719
+ residual = hidden_states
720
+ hidden_states = self.ffn_norm(hidden_states)
721
+ hidden_states = self.feed_forward(hidden_states, im_mask)
722
+ hidden_states = residual + hidden_states
723
+
724
+ outputs = (hidden_states, )
725
+
726
+ if output_attentions:
727
+ outputs += (self_attn_weights, )
728
+
729
+ if use_cache:
730
+ outputs += (present_key_value, )
731
+
732
+ return outputs
733
+
734
+
735
+ InternLM2_START_DOCSTRING = r"""
736
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
737
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
738
+ etc.)
739
+
740
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
741
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
742
+ and behavior.
743
+
744
+ Parameters:
745
+ config ([`InternLM2Config`]):
746
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
747
+ load the weights associated with the model, only the configuration. Check out the
748
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
749
+ """
750
+
751
+
752
+ @add_start_docstrings(
753
+ 'The bare InternLM2 Model outputting raw hidden-states without any specific head on top.',
754
+ InternLM2_START_DOCSTRING,
755
+ )
756
+ class InternLM2PreTrainedModel(PreTrainedModel):
757
+ config_class = InternLM2Config
758
+ base_model_prefix = 'model'
759
+ supports_gradient_checkpointing = True
760
+ _no_split_modules = ['InternLM2DecoderLayer']
761
+ _skip_keys_device_placement = 'past_key_values'
762
+
763
+ def _init_weights(self, module):
764
+ std = self.config.initializer_range
765
+ if isinstance(module, nn.Linear):
766
+ module.weight.data.normal_(mean=0.0, std=std)
767
+ if module.bias is not None:
768
+ module.bias.data.zero_()
769
+ elif isinstance(module, nn.Embedding):
770
+ module.weight.data.normal_(mean=0.0, std=std)
771
+ if module.padding_idx is not None:
772
+ module.weight.data[module.padding_idx].zero_()
773
+
774
+
775
+ InternLM2_INPUTS_DOCSTRING = r"""
776
+ Args:
777
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
778
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
779
+ it.
780
+
781
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
782
+ [`PreTrainedTokenizer.__call__`] for details.
783
+
784
+ [What are input IDs?](../glossary#input-ids)
785
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
786
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
787
+
788
+ - 1 for tokens that are **not masked**,
789
+ - 0 for tokens that are **masked**.
790
+
791
+ [What are attention masks?](../glossary#attention-mask)
792
+
793
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
794
+ [`PreTrainedTokenizer.__call__`] for details.
795
+
796
+ If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
797
+ `past_key_values`).
798
+
799
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
800
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
801
+ information on the default strategy.
802
+
803
+ - 1 indicates the head is **not masked**,
804
+ - 0 indicates the head is **masked**.
805
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
806
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
807
+ config.n_positions - 1]`.
808
+
809
+ [What are position IDs?](../glossary#position-ids)
810
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or
811
+ when `config.use_cache=True`):
812
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
813
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
814
+ `(batch_size, num_heads, decoder_sequence_length, embed_size_per_head)`.
815
+
816
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
817
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
818
+
819
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
820
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
821
+ of shape `(batch_size, sequence_length)`.
822
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
823
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
824
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
825
+ model's internal embedding lookup matrix.
826
+ use_cache (`bool`, *optional*):
827
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
828
+ `past_key_values`).
829
+ output_attentions (`bool`, *optional*):
830
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
831
+ tensors for more detail.
832
+ output_hidden_states (`bool`, *optional*):
833
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
834
+ more detail.
835
+ return_dict (`bool`, *optional*):
836
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
837
+ """
838
+
839
+
840
+ @add_start_docstrings(
841
+ 'The bare InternLM2 Model outputting raw hidden-states without any specific head on top.',
842
+ InternLM2_START_DOCSTRING,
843
+ )
844
+ class InternLM2Model(InternLM2PreTrainedModel):
845
+ """Transformer decoder consisting of *config.num_hidden_layers* layers.
846
+ Each layer is a [`InternLM2DecoderLayer`]
847
+
848
+ Args:
849
+ config: InternLM2Config
850
+ """
851
+
852
+ _auto_class = 'AutoModel'
853
+
854
+ def __init__(self, config: InternLM2Config):
855
+ super().__init__(config)
856
+ self.padding_idx = config.pad_token_id
857
+ self.vocab_size = config.vocab_size
858
+ self.config = config
859
+
860
+ self.tok_embeddings = nn.Embedding(config.vocab_size,
861
+ config.hidden_size,
862
+ self.padding_idx)
863
+ self.layers = nn.ModuleList([
864
+ InternLM2DecoderLayer(config)
865
+ for _ in range(config.num_hidden_layers)
866
+ ])
867
+ self.norm = InternLM2RMSNorm(
868
+ config.hidden_size, eps=config.rms_norm_eps)
869
+
870
+ self.trol_gating = nn.ModuleList([nn.Linear(self.config.hidden_size, 1)]*self.config.num_hidden_layers)
871
+ self.trol_function = lambda x, idx: 0.5*F.tanh(self.trol_gating[idx](x))+0.5
872
+
873
+ self.gradient_checkpointing = False
874
+ # Initialize weights and apply final processing
875
+ self.post_init()
876
+
877
+ def get_input_embeddings(self):
878
+ return self.tok_embeddings
879
+
880
+ def set_input_embeddings(self, value):
881
+ self.tok_embeddings = value
882
+
883
+ # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
884
+ def _prepare_decoder_attention_mask(self, attention_mask, input_shape,
885
+ inputs_embeds, past_key_values_length):
886
+ # create causal mask
887
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
888
+ combined_attention_mask = None
889
+ if input_shape[-1] > 1:
890
+ combined_attention_mask = _make_causal_mask(
891
+ input_shape,
892
+ inputs_embeds.dtype,
893
+ device=inputs_embeds.device,
894
+ past_key_values_length=past_key_values_length,
895
+ )
896
+
897
+ if attention_mask is not None:
898
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
899
+ expanded_attn_mask = _expand_mask(
900
+ attention_mask, inputs_embeds.dtype,
901
+ tgt_len=input_shape[-1]).to(inputs_embeds.device)
902
+ combined_attention_mask = (
903
+ expanded_attn_mask if combined_attention_mask is None else
904
+ expanded_attn_mask + combined_attention_mask)
905
+
906
+ return combined_attention_mask
907
+
908
+ @add_start_docstrings_to_model_forward(InternLM2_INPUTS_DOCSTRING)
909
+ def forward(self,
910
+ input_ids: torch.LongTensor = None,
911
+ attention_mask: Optional[torch.Tensor] = None,
912
+ position_ids: Optional[torch.LongTensor] = None,
913
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
914
+ inputs_embeds: Optional[torch.FloatTensor] = None,
915
+ use_cache: Optional[bool] = None,
916
+ output_attentions: Optional[bool] = None,
917
+ output_hidden_states: Optional[bool] = None,
918
+ return_dict: Optional[bool] = None,
919
+ **kwargs) -> Union[Tuple, BaseModelOutputWithPast]:
920
+
921
+ im_mask = kwargs.get('im_mask', None)
922
+
923
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
924
+ output_hidden_states = (
925
+ output_hidden_states if output_hidden_states is not None else
926
+ self.config.output_hidden_states)
927
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
928
+
929
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
930
+
931
+ if self.config.attn_implementation: _import_flash_attn()
932
+
933
+ # retrieve input_ids and inputs_embeds
934
+ if input_ids is not None and inputs_embeds is not None:
935
+ raise ValueError(
936
+ 'You cannot specify both input_ids and inputs_embeds at the same time'
937
+ )
938
+ elif input_ids is not None:
939
+ batch_size, seq_length = input_ids.shape[:2]
940
+ elif inputs_embeds is not None:
941
+ batch_size, seq_length = inputs_embeds.shape[:2]
942
+ else:
943
+ raise ValueError(
944
+ 'You have to specify either input_ids or inputs_embeds')
945
+
946
+ seq_length_with_past = seq_length
947
+ past_key_values_length = 0
948
+ if past_key_values is not None:
949
+ past_key_values_length = past_key_values[0][0].shape[2]
950
+ seq_length_with_past = seq_length_with_past + past_key_values_length
951
+
952
+ if position_ids is None:
953
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
954
+ position_ids = torch.arange(
955
+ past_key_values_length,
956
+ seq_length + past_key_values_length,
957
+ dtype=torch.long,
958
+ device=device)
959
+ position_ids = position_ids.unsqueeze(0)
960
+
961
+ if inputs_embeds is None:
962
+ inputs_embeds = self.tok_embeddings(input_ids)
963
+ im_mask = torch.zeros(inputs_embeds.shape[:2]).to(
964
+ inputs_embeds.device).bool()
965
+ if self.config.attn_implementation == "flash_attention_2":
966
+ # 2d mask is passed through the layers
967
+ attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
968
+ else:
969
+ if attention_mask is None:
970
+ attention_mask = torch.ones(
971
+ (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
972
+ )
973
+ attention_mask = self._prepare_decoder_attention_mask(
974
+ attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
975
+ )
976
+
977
+ # embed positions
978
+ hidden_states = inputs_embeds
979
+
980
+ if self.gradient_checkpointing and self.training:
981
+ if use_cache:
982
+ logger.warning_once(
983
+ '`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...'
984
+ )
985
+ use_cache = False
986
+
987
+ # decoder layers
988
+ all_hidden_states = () if output_hidden_states else None
989
+ all_self_attns = () if output_attentions else None
990
+ next_decoder_cache = () if use_cache else None
991
+
992
+ for idx, decoder_layer in enumerate(self.layers):
993
+ if output_hidden_states:
994
+ all_hidden_states += (hidden_states, )
995
+
996
+ past_key_value = past_key_values[
997
+ idx] if past_key_values is not None else None
998
+
999
+ if self.gradient_checkpointing and self.training:
1000
+
1001
+ def create_custom_forward(module):
1002
+
1003
+ def custom_forward(*inputs):
1004
+ # None for past_key_value
1005
+ return module(*inputs, output_attentions, None,
1006
+ im_mask)
1007
+
1008
+ return custom_forward
1009
+
1010
+ # TroL reusing
1011
+ original_hidden_states_list = []
1012
+ for _ in range(2):
1013
+ layer_outputs = torch.utils.checkpoint.checkpoint(
1014
+ create_custom_forward(decoder_layer),
1015
+ hidden_states,
1016
+ attention_mask,
1017
+ position_ids,
1018
+ None,
1019
+ )
1020
+ hidden_states = layer_outputs[0]
1021
+ original_hidden_states_list.append(layer_outputs[0])
1022
+ # Second TroL Gating & Feature Merging
1023
+ trol_score = self.trol_function(original_hidden_states_list[0], idx)
1024
+ updated_hidden_states = original_hidden_states_list[0] * (1 - trol_score) + original_hidden_states_list[1] * trol_score
1025
+
1026
+ else:
1027
+ if hidden_states.shape[1] > 1:
1028
+ # TroL reusing
1029
+ original_hidden_states_list = []
1030
+ original_past_key_value_list = []
1031
+ for _ in range(2):
1032
+ layer_outputs = decoder_layer(
1033
+ hidden_states,
1034
+ attention_mask=attention_mask,
1035
+ position_ids=position_ids,
1036
+ past_key_value=past_key_value,
1037
+ output_attentions=output_attentions,
1038
+ use_cache=use_cache,
1039
+ im_mask=im_mask,
1040
+ )
1041
+ hidden_states = layer_outputs[0]
1042
+ original_hidden_states_list.append(layer_outputs[0])
1043
+ original_past_key_value_list.append(layer_outputs[1])
1044
+ # Second TroL Gating & Feature Merging
1045
+ trol_score = self.trol_function(original_hidden_states_list[0], idx)
1046
+ updated_hidden_states = original_hidden_states_list[0] * (1-trol_score) + original_hidden_states_list[1] * trol_score
1047
+ updated_past_key = original_past_key_value_list[0][0]
1048
+ updated_past_value = original_past_key_value_list[0][1]
1049
+ else:
1050
+ # TroL reusing
1051
+ layer_outputs = decoder_layer(
1052
+ hidden_states,
1053
+ attention_mask=attention_mask,
1054
+ position_ids=position_ids,
1055
+ past_key_value=past_key_value,
1056
+ output_attentions=output_attentions,
1057
+ use_cache=use_cache,
1058
+ im_mask=im_mask,
1059
+ )
1060
+ updated_hidden_states = layer_outputs[0]
1061
+ updated_past_key = layer_outputs[1][0]
1062
+ updated_past_value = layer_outputs[1][1]
1063
+
1064
+ # hidden_states = layer_outputs[0] -> updated hidden states
1065
+ hidden_states = updated_hidden_states
1066
+
1067
+ if use_cache:
1068
+ next_decoder_cache += (
1069
+ (updated_past_key, updated_past_value), ) # updated past key values
1070
+
1071
+ if output_attentions:
1072
+ all_self_attns += (layer_outputs[1], )
1073
+
1074
+ hidden_states = self.norm(hidden_states)
1075
+
1076
+ # add hidden states from the last decoder layer
1077
+ if output_hidden_states:
1078
+ all_hidden_states += (hidden_states, )
1079
+
1080
+ next_cache = next_decoder_cache if use_cache else None
1081
+ if not return_dict:
1082
+ return tuple(
1083
+ v for v in
1084
+ [hidden_states, next_cache, all_hidden_states, all_self_attns]
1085
+ if v is not None)
1086
+ return BaseModelOutputWithPast(
1087
+ last_hidden_state=hidden_states,
1088
+ past_key_values=next_cache,
1089
+ hidden_states=all_hidden_states,
1090
+ attentions=all_self_attns,
1091
+ )
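The layer-reuse path in `InternLM2Model.forward` above runs each decoder layer twice and blends the two traversals with a learned gate. A standalone sketch of that mixing rule, with toy tensor sizes and a plain `nn.Linear` standing in for `InternLM2DecoderLayer`:

import torch
from torch import nn

hidden_size = 16
gate = nn.Linear(hidden_size, 1)                        # one gate per decoder layer in the model
trol_function = lambda x: 0.5 * torch.tanh(gate(x)) + 0.5   # score in (0, 1), mirroring trol_function above

layer = nn.Linear(hidden_size, hidden_size)             # stand-in for a decoder layer

hidden_states = torch.randn(2, 5, hidden_size)          # (batch, seq_len, hidden)
first_pass = layer(hidden_states)                       # first traversal of the layer
second_pass = layer(first_pass)                         # second traversal reusing the same layer

score = trol_function(first_pass)                       # gate computed from the first traversal
merged = first_pass * (1 - score) + second_pass * score # feature merging, as in the forward pass above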
trol/arch_internlm2/modeling_trol.py ADDED
@@ -0,0 +1,298 @@
1
+ # System
2
+ import torch
3
+ from torch import nn
4
+ from utils.utils import *
5
+ import torch.utils.checkpoint
6
+ from typing import List, Optional, Tuple, Union
7
+ from .build_module import build_vision_projector, build_vision_tower
8
+ from .modeling_internlm2 import InternLM2Model, InternLM2PreTrainedModel
9
+
10
+ # Dataclass & ModelOutput
11
+ from dataclasses import dataclass
12
+ from transformers.modeling_outputs import ModelOutput
13
+ @dataclass
14
+ class TroLCausalLMOutputWithPast(ModelOutput):
15
+ loss: Optional[torch.FloatTensor] = None
16
+ logits: torch.FloatTensor = None
17
+ past_key_values: Optional[List[torch.FloatTensor]] = None
18
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
19
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
20
+ image_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
21
+
22
+ class TroLForCausalLM(InternLM2PreTrainedModel):
23
+ _auto_class = 'AutoModelForCausalLM'
24
+
25
+ _tied_weights_keys = ['output.weight']
26
+
27
+ def __init__(self, config):
28
+ super().__init__(config)
29
+
30
+ # Model
31
+ self.model = InternLM2Model(config)
32
+ self.vocab_size = config.vocab_size
33
+ self.output = nn.Linear(config.hidden_size, config.vocab_size-1, bias=False)
34
+ self.max_length = config.max_length
35
+
36
+ # Initialize weights and apply final processing
37
+ self.post_init()
38
+
39
+ # Vision Encoder
40
+ self.vit = build_vision_tower()
41
+
42
+ # Vision Projection
43
+ self.vision_proj = build_vision_projector(self.config.hidden_size)
44
+
45
+ # image processing variable
46
+ self.mean = torch.tensor([0.48145466, 0.4578275, 0.40821073]).view(1,-1,1,1) * 255
47
+ self.std = torch.tensor([0.26862954, 0.26130258, 0.27577711]).view(1,-1,1,1) * 255
48
+
49
+ # prompt rule
50
+ self.prompt_rule = {"system_start": "<s>[UNUSED_TOKEN_146]system\n",
51
+ "system_end": "[UNUSED_TOKEN_145]",
52
+ "user_start": "[UNUSED_TOKEN_146]user\n",
53
+ "user_end": "[UNUSED_TOKEN_145]",
54
+ "assistant_start": "[UNUSED_TOKEN_146]assistant\n",
55
+ "assistant_end": "[UNUSED_TOKEN_145]\n</s>",
56
+ "test_start": "assistant\n",
57
+ "test_end": "[UNUSED_TOKEN_145]",
58
+ "split": "\n",
59
+ }
60
+
61
+ def image_processor(self, images):
62
+ norm_images = (images - self.mean.to(images.device)) / self.std.to(images.device)
63
+ return norm_images
64
+
65
+ def eval_process(
66
+ self,
67
+ inputs,
68
+ data,
69
+ tokenizer,
70
+ device,
71
+ img_token_number,
72
+ ):
73
+ batched_image = []
74
+ batched_qa_prompt=[]
75
+ for _input in inputs:
76
+
77
+ # Visualization
78
+ # imim = _input['image'].cpu().permute(1, 2, 0)
79
+
80
+ # add <image> to the question if an image is given but the tag is missing, before attaching the system prompt and <tor> prompt
81
+ if 'image' in _input.keys() and not '<image>' in _input['question']: _input['question'] = '<image>\n' + _input['question']
82
+
83
+ # make question and answer
84
+ question = make_instruction(_input['question'], data, self.prompt_rule)
85
+
86
+ # add bundle image tokens if it has <image> token
87
+ question = add_bundle_tokens(question, '<image>', img_token_number)
88
+
89
+ batched_qa_prompt.append(question)
90
+
91
+ # making batched image prompt
92
+ if 'image' in _input.keys() and _input['image'] != None: batched_image.append(_input['image'].to(device))
93
+
94
+ '''For Final Outputs'''
95
+ qa_prompts = tokenizer(batched_qa_prompt, padding='longest', return_tensors="pt", add_special_tokens=False)
96
+
97
+ # [1] input_ids
98
+ input_ids = qa_prompts.input_ids.to(device)
99
+
100
+ # [2] attention_mask
101
+ attention_mask = qa_prompts.attention_mask.to(device)
102
+
103
+ # [3] im_mask
104
+ im_mask = torch.zeros_like(input_ids).bool()
105
+ im_mask[torch.where(input_ids==self.config.image_token_index)] = True
106
+
107
+ if len(batched_image):
108
+ return {"input_ids": input_ids,
109
+ "attention_mask": attention_mask,
110
+ "im_mask": im_mask,
111
+ "image_features": self.clip_features(self.image_processor(torch.stack(batched_image)).to(device))
112
+ }
113
+ else:
114
+ return {"input_ids": input_ids,
115
+ "attention_mask": attention_mask,
116
+ "im_mask": im_mask,
117
+ }
118
+
119
+ def clip_features(self, image):
120
+ self.vit.eval()
121
+ return self.vit(image)
122
+
123
+ def _merge_input_embeds_with_image_features(self, image_features, inputs_embeds, input_ids):
124
+
125
+ # batch index for image feature
126
+ batch_ind_image_feature = 0
127
+
128
+ # shape of image_features
129
+ _, C, D = image_features.shape
130
+
131
+ for ind, input_id in enumerate(input_ids):
132
+ matching = torch.where(input_id==self.config.image_token_index)
133
+ num_image_tokens_per_one_sample = len(matching[0]) // C
134
+ inputs_embeds[ind][matching] = image_features[batch_ind_image_feature: batch_ind_image_feature+num_image_tokens_per_one_sample].view(-1, D)
135
+ batch_ind_image_feature += num_image_tokens_per_one_sample
136
+
137
+ def forward(
138
+ self,
139
+ input_ids: torch.LongTensor = None,
140
+ image_features: torch.FloatTensor = None,
141
+ attention_mask: Optional[torch.Tensor] = None,
142
+ im_mask: torch.BoolTensor = None,
143
+ position_ids: Optional[torch.LongTensor] = None,
144
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
145
+ inputs_embeds: Optional[torch.FloatTensor] = None,
146
+ labels: Optional[torch.LongTensor] = None,
147
+ use_cache: Optional[bool] = None,
148
+ output_attentions: Optional[bool] = None,
149
+ output_hidden_states: Optional[bool] = None,
150
+ return_dict: Optional[bool] = None,
151
+ ) -> Union[Tuple, TroLCausalLMOutputWithPast]:
152
+
153
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
154
+ output_hidden_states = (
155
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
156
+ )
157
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
158
+
159
+ if inputs_embeds is None:
160
+ # 1. Extract the input embeddings
161
+ inputs_embeds = self.get_input_embeddings()(input_ids)
162
+
163
+ # 2. Merge text and images
164
+ if image_features is not None and input_ids.shape[1] != 1:
165
+ image_features = self.vision_proj(image_features.to(inputs_embeds.dtype))
166
+ self._merge_input_embeds_with_image_features(image_features, inputs_embeds, input_ids)
167
+
168
+ # When input_ids.shape[1] == 1 and both image_features and past_key_values are given, we are in the case of
169
+ # generation with cache
170
+ elif past_key_values is not None and image_features is not None and input_ids.shape[1] == 1:
171
+ # Retrieve the first layer to inspect the logits and mask out the hidden states
172
+ # that are set to 0
173
+ first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
174
+
175
+ # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
176
+ batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
177
+
178
+ # Get the target length
179
+ target_length = input_ids.shape[1]
180
+ past_length = first_layer_past_key_value.shape[-1]
181
+
182
+ extended_attention_mask = torch.ones(
183
+ (attention_mask.shape[0], past_length),
184
+ dtype=attention_mask.dtype,
185
+ device=attention_mask.device,
186
+ )
187
+
188
+ # Filter out only the tokens that can be un-attended, this can happen
189
+ # if one uses Llava + Fused modules where the cache on the
190
+ # first iteration is already big enough, or if one passes custom cache
191
+ valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
192
+ new_batch_index = batch_index[valid_indices]
193
+ new_non_attended_tokens = non_attended_tokens[valid_indices]
194
+
195
+ # Zero-out the places where we don't need to attend
196
+ extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
197
+
198
+ attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
199
+ position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
200
+ im_mask = torch.zeros(inputs_embeds.shape[:2]).bool().to(inputs_embeds.device)
201
+
202
+ outputs = self.model(
203
+ attention_mask=attention_mask,
204
+ position_ids=position_ids,
205
+ past_key_values=past_key_values,
206
+ inputs_embeds=inputs_embeds,
207
+ use_cache=use_cache,
208
+ output_attentions=output_attentions,
209
+ output_hidden_states=output_hidden_states,
210
+ return_dict=return_dict,
211
+ im_mask=im_mask,
212
+ )
213
+
214
+ hidden_states = outputs[0]
215
+ logits = self.output(hidden_states)
216
+
217
+ loss = None
218
+ if labels is not None:
219
+ # Shift so that tokens < n predict n
220
+ if attention_mask is not None:
221
+ shift_attention_mask = attention_mask[..., 1:]
222
+ shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
223
+ shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
224
+ else:
225
+ shift_logits = logits[..., :-1, :].contiguous()
226
+ shift_labels = labels[..., 1:].contiguous()
227
+ # Flatten the tokens
228
+ loss_fct = nn.CrossEntropyLoss()
229
+ loss = loss_fct(
230
+ shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device)
231
+ )
232
+
233
+ if not return_dict:
234
+ output = (logits,) + outputs[1:]
235
+ return (loss,) + output if loss is not None else output
236
+
237
+ return TroLCausalLMOutputWithPast(
238
+ loss=loss,
239
+ logits=logits,
240
+ past_key_values=outputs.past_key_values,
241
+ hidden_states=outputs.hidden_states,
242
+ attentions=outputs.attentions,
243
+ )
244
+
245
+ def prepare_inputs_for_generation(self,
246
+ input_ids,
247
+ past_key_values=None,
248
+ attention_mask=None,
249
+ inputs_embeds=None,
250
+ image_features=None,
251
+ im_mask=None,
252
+ **kwargs):
253
+ if past_key_values is not None:
254
+ past_length = past_key_values[0][0].shape[2]
255
+
256
+ # Some generation methods already pass only the last input ID
257
+ if input_ids.shape[1] > past_length:
258
+ remove_prefix_length = past_length
259
+ else:
260
+ # Default to old behavior: keep only final ID
261
+ remove_prefix_length = input_ids.shape[1] - 1
262
+
263
+ input_ids = input_ids[:, remove_prefix_length:]
264
+
265
+ position_ids = kwargs.get('position_ids', None)
266
+ if attention_mask is not None and position_ids is None:
267
+ # create position_ids on the fly for batch generation
268
+ position_ids = attention_mask.long().cumsum(-1) - 1
269
+ position_ids.masked_fill_(attention_mask == 0, 1)
270
+ if past_key_values:
271
+ position_ids = position_ids[:, -input_ids.shape[1]:]
272
+
273
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
274
+ if inputs_embeds is not None and past_key_values is None:
275
+ model_inputs = {"inputs_embeds": inputs_embeds}
276
+ else:
277
+ model_inputs = {"input_ids": input_ids}
278
+
279
+ model_inputs.update(
280
+ {
281
+ "position_ids": position_ids,
282
+ "past_key_values": past_key_values,
283
+ "use_cache": kwargs.get("use_cache"),
284
+ "attention_mask": attention_mask,
285
+ "image_features": image_features,
286
+ "im_mask": im_mask,
287
+ }
288
+ )
289
+ return model_inputs
290
+
291
+ @staticmethod
292
+ def _reorder_cache(past_key_values, beam_idx):
293
+ reordered_past = ()
294
+ for layer_past in past_key_values:
295
+ reordered_past += (tuple(
296
+ past_state.index_select(0, beam_idx.to(past_state.device))
297
+ for past_state in layer_past), )
298
+ return reordered_past
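A toy illustration of the placeholder-replacement step performed by `_merge_input_embeds_with_image_features` above; the `999` image-token id and the sizes are made up for the example (the model reads the real id from `config.image_token_index`):

import torch

IMAGE_TOKEN_INDEX = 999                        # hypothetical placeholder id for this sketch
D = 8                                          # toy hidden size
input_ids = torch.tensor([[1, 999, 999, 2]])   # two image-token slots in one sequence
inputs_embeds = torch.randn(1, 4, D)           # text embeddings for the same sequence
image_features = torch.randn(2, D)             # one projected vision feature per slot

matching = torch.where(input_ids[0] == IMAGE_TOKEN_INDEX)
inputs_embeds[0][matching] = image_features    # scatter the vision features into the text embeddings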
trol/arch_internlm2/tokenization_internlm2.py ADDED
@@ -0,0 +1,252 @@
1
+ # Copyright (c) InternLM. All rights reserved.
2
+ #
3
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
4
+ # and OPT implementations in this library. It has been modified from its
5
+ # original forms to accommodate minor architectural differences compared
6
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
7
+ #
8
+ # Licensed under the Apache License, Version 2.0 (the "License");
9
+ # you may not use this file except in compliance with the License.
10
+ # You may obtain a copy of the License at
11
+ #
12
+ # http://www.apache.org/licenses/LICENSE-2.0
13
+ #
14
+ # Unless required by applicable law or agreed to in writing, software
15
+ # distributed under the License is distributed on an "AS IS" BASIS,
16
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17
+ # See the License for the specific language governing permissions and
18
+ # limitations under the License.
19
+ """Tokenization classes for InternLM."""
20
+ import os
21
+ from shutil import copyfile
22
+ from typing import Any, Dict, List, Optional, Tuple
23
+
24
+ import sentencepiece as spm
25
+ from transformers.tokenization_utils import PreTrainedTokenizer
26
+ from transformers.utils import logging
27
+
28
+ logger = logging.get_logger(__name__)
29
+
30
+ VOCAB_FILES_NAMES = {'vocab_file': './tokenizer.model'}
31
+
32
+ PRETRAINED_VOCAB_FILES_MAP = {}
33
+
34
+
35
+ class InternLM2Tokenizer(PreTrainedTokenizer):
36
+ """Construct a InternLM tokenizer. Based on byte-level Byte-Pair-Encoding.
37
+
38
+ Args:
39
+ vocab_file (`str`):
40
+ Path to the vocabulary file.
41
+ """
42
+
43
+ vocab_files_names = VOCAB_FILES_NAMES
44
+ pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
45
+ model_input_names = ['input_ids', 'attention_mask']
46
+ _auto_class = 'AutoTokenizer'
47
+
48
+ def __init__(
49
+ self,
50
+ vocab_file,
51
+ unk_token='<unk>',
52
+ bos_token='<s>',
53
+ eos_token='</s>',
54
+ pad_token='</s>',
55
+ sp_model_kwargs: Optional[Dict[str, Any]] = None,
56
+ add_bos_token=True,
57
+ add_eos_token=False,
58
+ decode_with_prefix_space=False,
59
+ clean_up_tokenization_spaces=False,
60
+ **kwargs,
61
+ ):
62
+ self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
63
+ self.vocab_file = vocab_file
64
+ self.add_bos_token = add_bos_token
65
+ self.add_eos_token = add_eos_token
66
+ self.decode_with_prefix_space = decode_with_prefix_space
67
+ self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
68
+ self.sp_model.Load(vocab_file)
69
+ self._no_prefix_space_tokens = None
70
+ super().__init__(
71
+ bos_token=bos_token,
72
+ eos_token=eos_token,
73
+ unk_token=unk_token,
74
+ pad_token=pad_token,
75
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
76
+ **kwargs,
77
+ )
78
+ """ Initialization"""
79
+
80
+ @property
81
+ def no_prefix_space_tokens(self):
82
+ if self._no_prefix_space_tokens is None:
83
+ vocab = self.convert_ids_to_tokens(list(range(self.vocab_size)))
84
+ self._no_prefix_space_tokens = {
85
+ i
86
+ for i, tok in enumerate(vocab) if not tok.startswith('▁')
87
+ }
88
+ return self._no_prefix_space_tokens
89
+
90
+ @property
91
+ def vocab_size(self):
92
+ """Returns vocab size."""
93
+ return self.sp_model.get_piece_size()
94
+
95
+ @property
96
+ def bos_token_id(self) -> Optional[int]:
97
+ return self.sp_model.bos_id()
98
+
99
+ @property
100
+ def eos_token_id(self) -> Optional[int]:
101
+ return self.sp_model.eos_id()
102
+
103
+ def get_vocab(self):
104
+ """Returns vocab as a dict."""
105
+ vocab = {
106
+ self.convert_ids_to_tokens(i): i
107
+ for i in range(self.vocab_size)
108
+ }
109
+ vocab.update(self.added_tokens_encoder)
110
+ return vocab
111
+
112
+ def _tokenize(self, text):
113
+ """Returns a tokenized string."""
114
+ return self.sp_model.encode(text, out_type=str)
115
+
116
+ def _convert_token_to_id(self, token):
117
+ """Converts a token (str) in an id using the vocab."""
118
+ return self.sp_model.piece_to_id(token)
119
+
120
+ def _convert_id_to_token(self, index):
121
+ """Converts an index (integer) in a token (str) using the vocab."""
122
+ token = self.sp_model.IdToPiece(index)
123
+ return token
124
+
125
+ def _maybe_add_prefix_space(self, tokens, decoded):
126
+ if tokens and tokens[0] not in self.no_prefix_space_tokens:
127
+ return ' ' + decoded
128
+ else:
129
+ return decoded
130
+
131
+ def convert_tokens_to_string(self, tokens):
132
+ """Converts a sequence of tokens (string) in a single string."""
133
+ current_sub_tokens = []
134
+ out_string = ''
135
+ prev_is_special = False
136
+ for token in tokens:
137
+ # make sure that special tokens are not decoded using sentencepiece model
138
+ if token in self.all_special_tokens:
139
+ if not prev_is_special:
140
+ out_string += ' '
141
+ out_string += self.sp_model.decode(current_sub_tokens) + token
142
+ prev_is_special = True
143
+ current_sub_tokens = []
144
+ else:
145
+ current_sub_tokens.append(token)
146
+ prev_is_special = False
147
+ out_string += self.sp_model.decode(current_sub_tokens)
148
+ out_string = self.clean_up_tokenization(out_string)
149
+ out_string = self._maybe_add_prefix_space(
150
+ tokens=tokens, decoded=out_string)
151
+ return out_string[1:]
152
+
153
+ def save_vocabulary(self,
154
+ save_directory,
155
+ filename_prefix: Optional[str] = None) -> Tuple[str]:
156
+ """Save the vocabulary and special tokens file to a directory.
157
+
158
+ Args:
159
+ save_directory (`str`):
160
+ The directory in which to save the vocabulary.
161
+
162
+ Returns:
163
+ `Tuple(str)`: Paths to the files saved.
164
+ """
165
+ if not os.path.isdir(save_directory):
166
+ logger.error(
167
+ f'Vocabulary path ({save_directory}) should be a directory')
168
+ return
169
+ out_vocab_file = os.path.join(
170
+ save_directory,
171
+ (filename_prefix + '-' if filename_prefix else '') +
172
+ VOCAB_FILES_NAMES['vocab_file'])
173
+
174
+ if os.path.abspath(self.vocab_file) != os.path.abspath(
175
+ out_vocab_file) and os.path.isfile(self.vocab_file):
176
+ copyfile(self.vocab_file, out_vocab_file)
177
+ elif not os.path.isfile(self.vocab_file):
178
+ with open(out_vocab_file, 'wb') as fi:
179
+ content_spiece_model = self.sp_model.serialized_model_proto()
180
+ fi.write(content_spiece_model)
181
+
182
+ return (out_vocab_file, )
183
+
184
+ def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
185
+ if self.add_bos_token:
186
+ bos_token_ids = [self.bos_token_id]
187
+ else:
188
+ bos_token_ids = []
189
+
190
+ output = bos_token_ids + token_ids_0
191
+
192
+ if token_ids_1 is not None:
193
+ output = output + token_ids_1
194
+
195
+ if self.add_eos_token:
196
+ output = output + [self.eos_token_id]
197
+
198
+ return output
199
+
200
+ def get_special_tokens_mask(
201
+ self,
202
+ token_ids_0: List[int],
203
+ token_ids_1: Optional[List[int]] = None,
204
+ already_has_special_tokens: bool = False) -> List[int]:
205
+ """Retrieve sequence ids from a token list that has no special tokens
206
+ added. This method is called when adding special tokens using the
207
+ tokenizer `prepare_for_model` method.
208
+
209
+ Args:
210
+ token_ids_0 (`List[int]`):
211
+ List of IDs.
212
+ token_ids_1 (`List[int]`, *optional*):
213
+ Optional second list of IDs for sequence pairs.
214
+ already_has_special_tokens (`bool`, *optional*, defaults to `False`):
215
+ Whether or not the token list is already formatted with special tokens for the model.
216
+
217
+ Returns:
218
+ `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
219
+ """
220
+ if already_has_special_tokens:
221
+ return super().get_special_tokens_mask(
222
+ token_ids_0=token_ids_0,
223
+ token_ids_1=token_ids_1,
224
+ already_has_special_tokens=True)
225
+
226
+ if token_ids_1 is None:
227
+ return [1] + ([0] * len(token_ids_0)) + [1]
228
+ return [1] + ([0] * len(token_ids_0)) + [1, 1] + (
229
+ [0] * len(token_ids_1)) + [1]
230
+
231
+ def create_token_type_ids_from_sequences(
232
+ self,
233
+ token_ids_0: List[int],
234
+ token_ids_1: Optional[List[int]] = None) -> List[int]:
235
+ """Create a mask from the two sequences passed to be used in a
236
+ sequence-pair classification task. InternLM2 does not make use of token type
237
+ ids, therefore a list of zeros is returned.
238
+
239
+ Args:
240
+ token_ids_0 (`List[int]`):
241
+ List of IDs.
242
+ token_ids_1 (`List[int]`, *optional*):
243
+ Optional second list of IDs for sequence pairs.
244
+
245
+ Returns:
246
+ `List[int]`: List of zeros.
247
+ """
248
+ eos = [self.eos_token_id]
249
+
250
+ if token_ids_1 is None:
251
+ return len(token_ids_0 + eos) * [0]
252
+ return len(token_ids_0 + eos + token_ids_1 + eos) * [0]
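The layout produced by `build_inputs_with_special_tokens` above, mirrored as a standalone function so it can be checked without a SentencePiece vocab file (token ids are arbitrary):

def build_inputs(token_ids_0, token_ids_1=None, bos_id=1, eos_id=2,
                 add_bos_token=True, add_eos_token=False):
    # <s> is prepended when add_bos_token=True; </s> is appended only when add_eos_token=True.
    out = ([bos_id] if add_bos_token else []) + list(token_ids_0)
    if token_ids_1 is not None:
        out += list(token_ids_1)
    return out + ([eos_id] if add_eos_token else [])

assert build_inputs([11, 12, 13]) == [1, 11, 12, 13]
assert build_inputs([11, 12], [21, 22], add_eos_token=True) == [1, 11, 12, 21, 22, 2]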
trol/arch_phi3/configuration_intern_vit.py ADDED
@@ -0,0 +1,69 @@
1
+ import os
2
+ from typing import Union
3
+
4
+ from transformers.configuration_utils import PretrainedConfig
5
+ from transformers.utils import logging
6
+
7
+ logger = logging.get_logger(__name__)
8
+
9
+
10
+ class InternVisionConfig(PretrainedConfig):
11
+ model_type = 'intern_vit_6b'
12
+
13
+ def __init__(
14
+ self,
15
+ num_channels=3,
16
+ patch_size=14,
17
+ image_size=224,
18
+ qkv_bias=False,
19
+ hidden_size=3200,
20
+ num_attention_heads=25,
21
+ intermediate_size=12800,
22
+ qk_normalization=True,
23
+ num_hidden_layers=48,
24
+ use_flash_attn=True,
25
+ hidden_act='gelu',
26
+ norm_type='rms_norm',
27
+ layer_norm_eps=1e-6,
28
+ dropout=0.0,
29
+ drop_path_rate=0.0,
30
+ attention_dropout=0.0,
31
+ initializer_range=0.02,
32
+ initializer_factor=0.1,
33
+ **kwargs,
34
+ ):
35
+ super().__init__(**kwargs)
36
+
37
+ self.hidden_size = hidden_size
38
+ self.intermediate_size = intermediate_size
39
+ self.dropout = dropout
40
+ self.drop_path_rate = drop_path_rate
41
+ self.num_hidden_layers = num_hidden_layers
42
+ self.num_attention_heads = num_attention_heads
43
+ self.num_channels = num_channels
44
+ self.patch_size = patch_size
45
+ self.image_size = image_size
46
+ self.initializer_range = initializer_range
47
+ self.initializer_factor = initializer_factor
48
+ self.attention_dropout = attention_dropout
49
+ self.layer_norm_eps = layer_norm_eps
50
+ self.hidden_act = hidden_act
51
+ self.norm_type = norm_type
52
+ self.qkv_bias = qkv_bias
53
+ self.qk_normalization = qk_normalization
54
+ self.use_flash_attn = use_flash_attn
55
+
56
+ @classmethod
57
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> 'PretrainedConfig':
58
+ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
59
+
60
+ if 'vision_config' in config_dict:
61
+ config_dict = config_dict['vision_config']
62
+
63
+ if 'model_type' in config_dict and hasattr(cls, 'model_type') and config_dict['model_type'] != cls.model_type:
64
+ logger.warning(
65
+ f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
66
+ f'{cls.model_type}. This is not supported for all configurations of models and can yield errors.'
67
+ )
68
+
69
+ return cls.from_dict(config_dict, **kwargs)
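A small usage sketch for the config class above; the import path assumes the package layout introduced by this commit:

from trol.arch_phi3.configuration_intern_vit import InternVisionConfig

cfg = InternVisionConfig(image_size=448, patch_size=14, use_flash_attn=False)
num_patches = (cfg.image_size // cfg.patch_size) ** 2   # 32 * 32 = 1024 patches, plus one CLS token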
trol/arch_phi3/configuration_phi3.py ADDED
@@ -0,0 +1,111 @@
1
+ from transformers.configuration_utils import PretrainedConfig
2
+ from transformers.utils import logging
3
+
4
+ logger = logging.get_logger(__name__)
5
+
6
+ PHI3_PRETRAINED_CONFIG_ARCHIVE_MAP = {
7
+ 'microsoft/Phi-3-mini-4k-instruct': 'https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/resolve/main/config.json',
8
+ 'microsoft/Phi-3-mini-128k-instruct': 'https://huggingface.co/microsoft/Phi-3-mini-128k-instruct/resolve/main/config.json',
9
+ }
10
+
11
+ class Phi3Config(PretrainedConfig):
12
+ model_type = 'phi3'
13
+ keys_to_ignore_at_inference = ['past_key_values']
14
+
15
+ def __init__(
16
+ self,
17
+ vocab_size=32064,
18
+ hidden_size=3072,
19
+ intermediate_size=8192,
20
+ num_hidden_layers=32,
21
+ num_attention_heads=32,
22
+ num_key_value_heads=None,
23
+ resid_pdrop=0.0,
24
+ embd_pdrop=0.0,
25
+ attention_dropout=0.0,
26
+ hidden_act='silu',
27
+ max_position_embeddings=4096,
28
+ original_max_position_embeddings=4096,
29
+ initializer_range=0.02,
30
+ rms_norm_eps=1e-5,
31
+ use_cache=True,
32
+ tie_word_embeddings=False,
33
+ rope_theta=10000.0,
34
+ rope_scaling=None,
35
+ bos_token_id=1,
36
+ eos_token_id=32000,
37
+ pad_token_id=32000,
38
+ sliding_window=None,
39
+ **kwargs,
40
+ ):
41
+ self.vocab_size = vocab_size
42
+ self.hidden_size = hidden_size
43
+ self.intermediate_size = intermediate_size
44
+ self.num_hidden_layers = num_hidden_layers
45
+ self.num_attention_heads = num_attention_heads
46
+
47
+ if num_key_value_heads is None:
48
+ num_key_value_heads = num_attention_heads
49
+
50
+ self.num_key_value_heads = num_key_value_heads
51
+ self.resid_pdrop = resid_pdrop
52
+ self.embd_pdrop = embd_pdrop
53
+ self.attention_dropout = attention_dropout
54
+ self.hidden_act = hidden_act
55
+ self.max_position_embeddings = max_position_embeddings
56
+ self.original_max_position_embeddings = original_max_position_embeddings
57
+ self.initializer_range = initializer_range
58
+ self.rms_norm_eps = rms_norm_eps
59
+ self.use_cache = use_cache
60
+ self.rope_theta = rope_theta
61
+ self.rope_scaling = rope_scaling
62
+ self._rope_scaling_validation()
63
+ self.sliding_window = sliding_window
64
+
65
+ super().__init__(
66
+ bos_token_id=bos_token_id,
67
+ eos_token_id=eos_token_id,
68
+ pad_token_id=pad_token_id,
69
+ tie_word_embeddings=tie_word_embeddings,
70
+ **kwargs,
71
+ )
72
+
73
+ def _rope_scaling_validation(self):
74
+ """
75
+ Validate the `rope_scaling` configuration.
76
+ """
77
+ if self.rope_scaling is None:
78
+ return
79
+
80
+ if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 3:
81
+ raise ValueError(
82
+ '`rope_scaling` must be a dictionary with three fields, `type`, `short_factor` and `long_factor`, '
83
+ f'got {self.rope_scaling}'
84
+ )
85
+ rope_scaling_type = self.rope_scaling.get('type', None)
86
+ rope_scaling_short_factor = self.rope_scaling.get('short_factor', None)
87
+ rope_scaling_long_factor = self.rope_scaling.get('long_factor', None)
88
+ if rope_scaling_type is None or rope_scaling_type not in ['su', 'yarn']:
89
+ raise ValueError(f"`rope_scaling`'s type field must be one of ['su', 'yarn'], got {rope_scaling_type}")
90
+ if not (
91
+ isinstance(rope_scaling_short_factor, list)
92
+ and all(isinstance(x, (int, float)) for x in rope_scaling_short_factor)
93
+ ):
94
+ raise ValueError(
95
+ f"`rope_scaling`'s short_factor field must be a list of numbers, got {rope_scaling_short_factor}"
96
+ )
97
+ if not len(rope_scaling_short_factor) == self.hidden_size // self.num_attention_heads // 2:
98
+ raise ValueError(
99
+ f"`rope_scaling`'s short_factor field must have length {self.hidden_size // self.num_attention_heads // 2}, got {len(rope_scaling_short_factor)}"
100
+ )
101
+ if not (
102
+ isinstance(rope_scaling_long_factor, list)
103
+ and all(isinstance(x, (int, float)) for x in rope_scaling_long_factor)
104
+ ):
105
+ raise ValueError(
106
+ f"`rope_scaling`'s long_factor field must be a list of numbers, got {rope_scaling_long_factor}"
107
+ )
108
+ if not len(rope_scaling_long_factor) == self.hidden_size // self.num_attention_heads // 2:
109
+ raise ValueError(
110
+ f"`rope_scaling`'s long_factor field must have length {self.hidden_size // self.num_attention_heads // 2}, got {len(rope_scaling_long_factor)}"
111
+ )
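A sketch of the constraint `_rope_scaling_validation` above enforces: `short_factor` and `long_factor` must each have `hidden_size // num_attention_heads // 2` entries (48 with the default Phi-3-mini sizes) and `type` must be `'su'` or `'yarn'`. The import path assumes this commit's package layout:

from trol.arch_phi3.configuration_phi3 import Phi3Config

factor_len = 3072 // 32 // 2          # hidden_size // num_attention_heads // 2 = 48
cfg = Phi3Config(rope_scaling={
    'type': 'su',
    'short_factor': [1.0] * factor_len,
    'long_factor': [1.2] * factor_len,
})                                     # passes validation; wrong lengths raise ValueError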
trol/arch_phi3/modeling_intern_vit.py ADDED
@@ -0,0 +1,412 @@
1
+ from typing import Optional, Tuple, Union
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ import torch.utils.checkpoint
6
+ from einops import rearrange
7
+ from timm.models.layers import DropPath
8
+ from torch import nn
9
+ from transformers.activations import ACT2FN
10
+ from transformers.modeling_outputs import (BaseModelOutput,
11
+ BaseModelOutputWithPooling)
12
+ from transformers.modeling_utils import PreTrainedModel
13
+ from transformers.utils import logging
14
+
15
+ from .configuration_intern_vit import InternVisionConfig
16
+
17
+ try:
18
+ try: # v1
19
+ from flash_attn.flash_attn_interface import \
20
+ flash_attn_unpadded_qkvpacked_func
21
+ except: # v2
22
+ from flash_attn.flash_attn_interface import \
23
+ flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func
24
+
25
+ from flash_attn.bert_padding import pad_input, unpad_input
26
+
27
+ has_flash_attn = True
28
+ except:
29
+ print('FlashAttention is not installed.')
30
+ has_flash_attn = False
31
+
32
+ logger = logging.get_logger(__name__)
33
+
34
+
35
+ class FlashAttention(nn.Module):
36
+ def __init__(self, softmax_scale=None, attention_dropout=0.0, device=None, dtype=None):
37
+ super().__init__()
38
+ self.softmax_scale = softmax_scale
39
+ self.dropout_p = attention_dropout
40
+
41
+ def forward(self, qkv, key_padding_mask=None, causal=False, cu_seqlens=None,
42
+ max_s=None, need_weights=False):
43
+ assert not need_weights
44
+ assert qkv.dtype in [torch.float16, torch.bfloat16]
45
+ assert qkv.is_cuda
46
+
47
+ if cu_seqlens is None:
48
+ batch_size = qkv.shape[0]
49
+ seqlen = qkv.shape[1]
50
+ if key_padding_mask is None:
51
+ qkv = rearrange(qkv, 'b s ... -> (b s) ...')
52
+ max_s = seqlen
53
+ cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
54
+ device=qkv.device)
55
+ output = flash_attn_unpadded_qkvpacked_func(
56
+ qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
57
+ softmax_scale=self.softmax_scale, causal=causal
58
+ )
59
+ output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
60
+ else:
61
+ nheads = qkv.shape[-2]
62
+ x = rearrange(qkv, 'b s three h d -> b s (three h d)')
63
+ x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask)
64
+ x_unpad = rearrange(x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads)
65
+ output_unpad = flash_attn_unpadded_qkvpacked_func(
66
+ x_unpad, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
67
+ softmax_scale=self.softmax_scale, causal=causal
68
+ )
69
+ output = rearrange(pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'),
70
+ indices, batch_size, seqlen),
71
+ 'b s (h d) -> b s h d', h=nheads)
72
+ else:
73
+ assert max_s is not None
74
+ output = flash_attn_unpadded_qkvpacked_func(
75
+ qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
76
+ softmax_scale=self.softmax_scale, causal=causal
77
+ )
78
+
79
+ return output, None
80
+
81
+
82
+ class InternRMSNorm(nn.Module):
83
+ def __init__(self, hidden_size, eps=1e-6):
84
+ super().__init__()
85
+ self.weight = nn.Parameter(torch.ones(hidden_size))
86
+ self.variance_epsilon = eps
87
+
88
+ def forward(self, hidden_states):
89
+ input_dtype = hidden_states.dtype
90
+ hidden_states = hidden_states.to(torch.float32)
91
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
92
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
93
+ return self.weight * hidden_states.to(input_dtype)
94
+
95
+
96
+ # try:
97
+ # from apex.normalization import FusedRMSNorm
98
+
99
+ # InternRMSNorm = FusedRMSNorm # noqa
100
+
101
+ # logger.info('Discovered apex.normalization.FusedRMSNorm - will use it instead of InternRMSNorm')
102
+ # except ImportError:
103
+ # # using the normal InternRMSNorm
104
+ # pass
105
+ # except Exception:
106
+ # logger.warning('discovered apex but it failed to load, falling back to InternRMSNorm')
107
+ # pass
108
+
109
+
110
+ NORM2FN = {
111
+ 'rms_norm': InternRMSNorm,
112
+ 'layer_norm': nn.LayerNorm,
113
+ }
114
+
115
+
116
+ class InternVisionEmbeddings(nn.Module):
117
+ def __init__(self, config: InternVisionConfig):
118
+ super().__init__()
119
+ self.config = config
120
+ self.embed_dim = config.hidden_size
121
+ self.image_size = config.image_size
122
+ self.patch_size = config.patch_size
123
+
124
+ self.class_embedding = nn.Parameter(
125
+ torch.randn(1, 1, self.embed_dim),
126
+ )
127
+
128
+ self.patch_embedding = nn.Conv2d(
129
+ in_channels=3, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size
130
+ )
131
+
132
+ self.num_patches = (self.image_size // self.patch_size) ** 2
133
+ self.num_positions = self.num_patches + 1
134
+
135
+ self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim))
136
+
137
+ def _get_pos_embed(self, pos_embed, H, W):
138
+ target_dtype = pos_embed.dtype
139
+ pos_embed = pos_embed.float().reshape(
140
+ 1, self.image_size // self.patch_size, self.image_size // self.patch_size, -1).permute(0, 3, 1, 2)
141
+ pos_embed = F.interpolate(pos_embed, size=(H, W), mode='bicubic', align_corners=False). \
142
+ reshape(1, -1, H * W).permute(0, 2, 1).to(target_dtype)
143
+ return pos_embed
144
+
145
+ def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
146
+ target_dtype = self.patch_embedding.weight.dtype
147
+ patch_embeds = self.patch_embedding(pixel_values) # shape = [*, channel, width, height]
148
+ batch_size, _, height, width = patch_embeds.shape
149
+ patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
150
+ class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype)
151
+ embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
152
+ position_embedding = torch.cat([
153
+ self.position_embedding[:, :1, :],
154
+ self._get_pos_embed(self.position_embedding[:, 1:, :], height, width)
155
+ ], dim=1)
156
+ embeddings = embeddings + position_embedding.to(target_dtype)
157
+ return embeddings
158
+
159
+
160
+ class InternAttention(nn.Module):
161
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
162
+
163
+ def __init__(self, config: InternVisionConfig):
164
+ super().__init__()
165
+ self.config = config
166
+ self.embed_dim = config.hidden_size
167
+ self.num_heads = config.num_attention_heads
168
+ self.use_flash_attn = config.use_flash_attn and has_flash_attn
169
+ if config.use_flash_attn and not has_flash_attn:
170
+ print('Warning: Flash Attention is not available, use_flash_attn is set to False.')
171
+ self.head_dim = self.embed_dim // self.num_heads
172
+ if self.head_dim * self.num_heads != self.embed_dim:
173
+ raise ValueError(
174
+ f'embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:'
175
+ f' {self.num_heads}).'
176
+ )
177
+
178
+ self.scale = self.head_dim ** -0.5
179
+ self.qkv = nn.Linear(self.embed_dim, 3 * self.embed_dim, bias=config.qkv_bias)
180
+ self.attn_drop = nn.Dropout(config.attention_dropout)
181
+ self.proj_drop = nn.Dropout(config.dropout)
182
+
183
+ self.qk_normalization = config.qk_normalization
184
+
185
+ if self.qk_normalization:
186
+ self.q_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
187
+ self.k_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
188
+
189
+ if self.use_flash_attn:
190
+ self.inner_attn = FlashAttention(attention_dropout=config.attention_dropout)
191
+ self.proj = nn.Linear(self.embed_dim, self.embed_dim)
192
+
193
+ def _naive_attn(self, x):
194
+ B, N, C = x.shape
195
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
196
+ q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
197
+
198
+ if self.qk_normalization:
199
+ B_, H_, N_, D_ = q.shape
200
+ q = self.q_norm(q.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
201
+ k = self.k_norm(k.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
202
+
203
+ attn = ((q * self.scale) @ k.transpose(-2, -1))
204
+ attn = attn.softmax(dim=-1)
205
+ attn = self.attn_drop(attn)
206
+
207
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
208
+ x = self.proj(x)
209
+ x = self.proj_drop(x)
210
+ return x
211
+
212
+ def _flash_attn(self, x, key_padding_mask=None, need_weights=False):
213
+ qkv = self.qkv(x)
214
+ qkv = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3, h=self.num_heads)
215
+
216
+ if self.qk_normalization:
217
+ q, k, v = qkv.unbind(2)
218
+ q = self.q_norm(q.flatten(-2, -1)).view(q.shape)
219
+ k = self.k_norm(k.flatten(-2, -1)).view(k.shape)
220
+ qkv = torch.stack([q, k, v], dim=2)
221
+
222
+ context, _ = self.inner_attn(
223
+ qkv, key_padding_mask=key_padding_mask, need_weights=need_weights, causal=False
224
+ )
225
+ outs = self.proj(rearrange(context, 'b s h d -> b s (h d)'))
226
+ outs = self.proj_drop(outs)
227
+ return outs
228
+
229
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
230
+ x = self._naive_attn(hidden_states) if not self.use_flash_attn else self._flash_attn(hidden_states)
231
+ return x
232
+
233
+
234
+ class InternMLP(nn.Module):
235
+ def __init__(self, config: InternVisionConfig):
236
+ super().__init__()
237
+ self.config = config
238
+ self.act = ACT2FN[config.hidden_act]
239
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
240
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
241
+
242
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
243
+ hidden_states = self.fc1(hidden_states)
244
+ hidden_states = self.act(hidden_states)
245
+ hidden_states = self.fc2(hidden_states)
246
+ return hidden_states
247
+
248
+
249
+ class InternVisionEncoderLayer(nn.Module):
250
+ def __init__(self, config: InternVisionConfig, drop_path_rate: float):
251
+ super().__init__()
252
+ self.embed_dim = config.hidden_size
253
+ self.intermediate_size = config.intermediate_size
254
+ self.norm_type = config.norm_type
255
+
256
+ self.attn = InternAttention(config)
257
+ self.mlp = InternMLP(config)
258
+ self.norm1 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps)
259
+ self.norm2 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps)
260
+
261
+ self.ls1 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))
262
+ self.ls2 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))
263
+ self.drop_path1 = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
264
+ self.drop_path2 = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
265
+
266
+ def forward(
267
+ self,
268
+ hidden_states: torch.Tensor,
269
+ ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor], Optional[Tuple[torch.FloatTensor]]]:
270
+ """
271
+ Args:
272
+ hidden_states (`Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]`): input to the layer of shape `(batch, seq_len, embed_dim)`
273
+ """
274
+ hidden_states = hidden_states + self.drop_path1(self.attn(self.norm1(hidden_states)) * self.ls1)
275
+
276
+ hidden_states = hidden_states + self.drop_path2(self.mlp(self.norm2(hidden_states)) * self.ls2)
277
+
278
+ return hidden_states
279
+
280
+
281
+ class InternVisionEncoder(nn.Module):
282
+ """
283
+ Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
284
+ [`InternEncoderLayer`].
285
+
286
+ Args:
287
+ config (`InternConfig`):
288
+ The corresponding vision configuration for the `InternEncoder`.
289
+ """
290
+
291
+ def __init__(self, config: InternVisionConfig):
292
+ super().__init__()
293
+ self.config = config
294
+ # stochastic depth decay rule
295
+ dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.num_hidden_layers)]
296
+ self.layers = nn.ModuleList([
297
+ InternVisionEncoderLayer(config, dpr[idx]) for idx in range(config.num_hidden_layers)])
298
+ self.gradient_checkpointing = True
299
+
300
+ def forward(
301
+ self,
302
+ inputs_embeds,
303
+ output_hidden_states: Optional[bool] = None,
304
+ return_dict: Optional[bool] = None,
305
+ ) -> Union[Tuple, BaseModelOutput]:
306
+ r"""
307
+ Args:
308
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
309
+ Embedded representation of the inputs. Should be float, not int tokens.
310
+ output_hidden_states (`bool`, *optional*):
311
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
312
+ for more detail.
313
+ return_dict (`bool`, *optional*):
314
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
315
+ """
316
+ output_hidden_states = (
317
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
318
+ )
319
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
320
+
321
+ encoder_states = () if output_hidden_states else None
322
+ hidden_states = inputs_embeds
323
+
324
+ for idx, encoder_layer in enumerate(self.layers):
325
+ if output_hidden_states:
326
+ encoder_states = encoder_states + (hidden_states,)
327
+ if self.gradient_checkpointing and self.training:
328
+ layer_outputs = torch.utils.checkpoint.checkpoint(
329
+ encoder_layer,
330
+ hidden_states)
331
+ else:
332
+ layer_outputs = encoder_layer(
333
+ hidden_states,
334
+ )
335
+ hidden_states = layer_outputs
336
+
337
+ if output_hidden_states:
338
+ encoder_states = encoder_states + (hidden_states,)
339
+
340
+ if not return_dict:
341
+ return tuple(v for v in [hidden_states, encoder_states] if v is not None)
342
+ return BaseModelOutput(
343
+ last_hidden_state=hidden_states, hidden_states=encoder_states
344
+ )
345
+
346
+
347
+ class InternVisionModel(PreTrainedModel):
348
+ main_input_name = 'pixel_values'
349
+ config_class = InternVisionConfig
350
+ _no_split_modules = ['InternVisionEncoderLayer']
351
+
352
+ def __init__(self, config: InternVisionConfig):
353
+ super().__init__(config)
354
+ self.config = config
355
+
356
+ self.embeddings = InternVisionEmbeddings(config)
357
+ self.encoder = InternVisionEncoder(config)
358
+
359
+ def resize_pos_embeddings(self, old_size, new_size, patch_size):
360
+ pos_emb = self.embeddings.position_embedding
361
+ _, num_positions, embed_dim = pos_emb.shape
362
+ cls_emb = pos_emb[:, :1, :]
363
+ pos_emb = pos_emb[:, 1:, :].reshape(1, old_size // patch_size, old_size // patch_size, -1).permute(0, 3, 1, 2)
364
+ pos_emb = F.interpolate(pos_emb.float(), size=new_size // patch_size, mode='bicubic', align_corners=False)
365
+ pos_emb = pos_emb.to(cls_emb.dtype).reshape(1, embed_dim, -1).permute(0, 2, 1)
366
+ pos_emb = torch.cat([cls_emb, pos_emb], dim=1)
367
+ self.embeddings.position_embedding = nn.Parameter(pos_emb)
368
+ self.embeddings.image_size = new_size
369
+ logger.info('Resized position embeddings from {} to {}'.format(old_size, new_size))
370
+
371
+ def get_input_embeddings(self):
372
+ return self.embeddings
373
+
374
+ def forward(
375
+ self,
376
+ pixel_values: Optional[torch.FloatTensor] = None,
377
+ output_hidden_states: Optional[bool] = None,
378
+ return_dict: Optional[bool] = None,
379
+ pixel_embeds: Optional[torch.FloatTensor] = None,
380
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
381
+ output_hidden_states = (
382
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
383
+ )
384
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
385
+
386
+ if pixel_values is None and pixel_embeds is None:
387
+ raise ValueError('You have to specify pixel_values or pixel_embeds')
388
+
389
+ if pixel_embeds is not None:
390
+ hidden_states = pixel_embeds
391
+ else:
392
+ if len(pixel_values.shape) == 4:
393
+ hidden_states = self.embeddings(pixel_values)
394
+ else:
395
+ raise ValueError(f'wrong pixel_values size: {pixel_values.shape}')
396
+ encoder_outputs = self.encoder(
397
+ inputs_embeds=hidden_states,
398
+ output_hidden_states=output_hidden_states,
399
+ return_dict=return_dict,
400
+ )
401
+ last_hidden_state = encoder_outputs.last_hidden_state
402
+ pooled_output = last_hidden_state[:, 0, :]
403
+
404
+ if not return_dict:
405
+ return (last_hidden_state, pooled_output) + encoder_outputs[1:]
406
+
407
+ return BaseModelOutputWithPooling(
408
+ last_hidden_state=last_hidden_state,
409
+ pooler_output=pooled_output,
410
+ hidden_states=encoder_outputs.hidden_states,
411
+ attentions=encoder_outputs.attentions,
412
+ )
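A hedged, CPU-sized sketch of running the vision tower defined above. The tiny hyper-parameters are chosen only so the example is cheap; the real checkpoint is far larger, and the imports assume this commit's package layout:

import torch
from trol.arch_phi3.configuration_intern_vit import InternVisionConfig
from trol.arch_phi3.modeling_intern_vit import InternVisionModel

cfg = InternVisionConfig(hidden_size=64, intermediate_size=128, num_hidden_layers=2,
                         num_attention_heads=4, image_size=56, patch_size=14,
                         use_flash_attn=False, qk_normalization=False)
model = InternVisionModel(cfg).eval()

pixel_values = torch.randn(1, 3, 56, 56)
out = model(pixel_values=pixel_values, return_dict=True)
# 56 / 14 = 4 patches per side -> 16 patch tokens + 1 CLS token
assert out.last_hidden_state.shape == (1, 17, 64)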
trol/arch_phi3/modeling_phi3.py ADDED
@@ -0,0 +1,1614 @@
1
+ import inspect
2
+ import math
3
+ import warnings
4
+ from typing import List, Optional, Tuple, Union
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+ import torch.utils.checkpoint
9
+ from torch import nn
10
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
11
+ from transformers.activations import ACT2FN
12
+ from transformers.cache_utils import Cache, DynamicCache
13
+ from transformers.modeling_attn_mask_utils import \
14
+ _prepare_4d_causal_attention_mask
15
+ from transformers.modeling_outputs import (BaseModelOutputWithPast,
16
+ CausalLMOutputWithPast,
17
+ SequenceClassifierOutputWithPast,
18
+ TokenClassifierOutput)
19
+ from transformers.modeling_utils import PreTrainedModel
20
+ from transformers.utils import (add_code_sample_docstrings,
21
+ add_start_docstrings,
22
+ add_start_docstrings_to_model_forward,
23
+ is_flash_attn_2_available,
24
+ is_flash_attn_greater_or_equal_2_10, logging,
25
+ replace_return_docstrings)
26
+
27
+ from .configuration_phi3 import Phi3Config
28
+
29
+ logger = logging.get_logger(__name__)
30
+
31
+ # Transformers scans dependencies in the modeling file, causing issues on conditional loading. The regex only ignores try/catch blocks, but not if statements
32
+ # if is_flash_attn_2_available():
33
+ _flash_supports_window_size = False
34
+ try:
35
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
36
+ from flash_attn.bert_padding import (index_first_axis, pad_input, # noqa
37
+ unpad_input)
38
+
39
+ _flash_supports_window_size = 'window_size' in list(inspect.signature(flash_attn_func).parameters)
40
+ except ImportError as error:
41
+ logger.warning(
42
+ f'`flash-attention` package not found, consider installing for better performance: {error}.'
43
+ )
44
+ if not _flash_supports_window_size:
45
+ logger.warning(
46
+ "Current `flash-attenton` does not support `window_size`. Either upgrade or use `attn_implementation='eager'`."
47
+ )
48
+
49
+ _CHECKPOINT_FOR_DOC = 'microsoft/Phi-3-mini-4k-instruct'
50
+ _CONFIG_FOR_DOC = 'Phi3Config'
51
+
52
+ PHI3_PRETRAINED_MODEL_ARCHIVE_LIST = [
53
+ 'microsoft/Phi-3-mini-4k-instruct',
54
+ 'microsoft/Phi-3-mini-128k-instruct',
55
+ # See all Phi-3 models at https://huggingface.co/models?filter=Phi-3
56
+ ]
57
+
58
+
59
+ # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Phi3
60
+ class Phi3RMSNorm(nn.Module):
61
+ def __init__(self, hidden_size, eps=1e-6):
62
+ """
63
+ Phi3RMSNorm is equivalent to T5LayerNorm
64
+ """
65
+ super().__init__()
66
+ self.weight = nn.Parameter(torch.ones(hidden_size))
67
+ self.variance_epsilon = eps
68
+
69
+ def forward(self, hidden_states):
70
+ input_dtype = hidden_states.dtype
71
+ hidden_states = hidden_states.to(torch.float32)
72
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
73
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
74
+ return self.weight * hidden_states.to(input_dtype)
75
+
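
As a quick cross-check of the normalization above, a standalone sketch of the RMSNorm computation with toy sizes (the weight is taken at its initial value of ones):

import torch

x = torch.randn(2, 5, 8)                      # (batch, seq, hidden), illustrative
eps = 1e-6
weight = torch.ones(8)                        # learned scale, initialised to ones
variance = x.pow(2).mean(-1, keepdim=True)    # mean of squares over the hidden dim
out = weight * (x * torch.rsqrt(variance + eps))
assert out.shape == x.shape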
76
+
77
+ # Copied from transformers.models.llama.modeling_llama._get_unpad_data
78
+ def _get_unpad_data(attention_mask):
79
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
80
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
81
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
82
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
83
+ return (
84
+ indices,
85
+ cu_seqlens,
86
+ max_seqlen_in_batch,
87
+ )
88
+
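
A small worked example of the unpadding bookkeeping above, using a toy two-sequence padding mask (1 = real token, 0 = padding):

import torch
import torch.nn.functional as F

attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]])
seqlens = attention_mask.sum(dim=-1, dtype=torch.int32)                        # tensor([3, 2])
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()    # tensor([0, 1, 2, 4, 5])
cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))    # tensor([0, 3, 5])
max_seqlen = seqlens.max().item()                                              # 3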
89
+
90
+ # Copied from transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding with gemma->phi3, Gemma->Phi3
91
+ class Phi3RotaryEmbedding(nn.Module):
92
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
93
+ super().__init__()
94
+
95
+ self.dim = dim
96
+ self.max_position_embeddings = max_position_embeddings
97
+ self.base = base
98
+ self.register_buffer('inv_freq', None, persistent=False)
99
+
100
+ @torch.no_grad()
101
+ def forward(self, x, position_ids, seq_len=None):
102
+ # x: [bs, num_attention_heads, seq_len, head_size]
103
+ if self.inv_freq is None:
104
+ self.inv_freq = 1.0 / (
105
+ self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim)
106
+ )
107
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
108
+ position_ids_expanded = position_ids[:, None, :].float()
109
+ # Force float32 since bfloat16 loses precision on long contexts
110
+ # See https://github.com/huggingface/transformers/pull/29285
111
+ device_type = x.device.type
112
+ device_type = device_type if isinstance(device_type, str) and device_type != 'mps' else 'cpu'
113
+ with torch.autocast(device_type=device_type, enabled=False):
114
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
115
+ emb = torch.cat((freqs, freqs), dim=-1)
116
+ cos = emb.cos()
117
+ sin = emb.sin()
118
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
119
+
120
+
121
+ class Phi3SuScaledRotaryEmbedding(Phi3RotaryEmbedding):
122
+ def __init__(self, dim, config, device=None):
123
+ super().__init__(dim, config.max_position_embeddings, config.rope_theta, device)
124
+
125
+ self.short_factor = config.rope_scaling['short_factor']
126
+ self.long_factor = config.rope_scaling['long_factor']
127
+ self.original_max_position_embeddings = config.original_max_position_embeddings
128
+
129
+ @torch.no_grad()
130
+ def forward(self, x, position_ids, seq_len=None):
131
+ seq_len = torch.max(position_ids) + 1
132
+ if seq_len > self.original_max_position_embeddings:
133
+ ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=x.device)
134
+ else:
135
+ ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=x.device)
136
+
137
+ inv_freq_shape = torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim
138
+ self.inv_freq = 1.0 / (ext_factors * self.base**inv_freq_shape)
139
+
140
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
141
+ position_ids_expanded = position_ids[:, None, :].float()
142
+
143
+ # Force float32 since bfloat16 loses precision on long contexts
144
+ # See https://github.com/huggingface/transformers/pull/29285
145
+ device_type = x.device.type
146
+ device_type = device_type if isinstance(device_type, str) and device_type != 'mps' else 'cpu'
147
+ with torch.autocast(device_type=device_type, enabled=False):
148
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
149
+ emb = torch.cat((freqs, freqs), dim=-1)
150
+
151
+ scale = self.max_position_embeddings / self.original_max_position_embeddings
152
+ if scale <= 1.0:
153
+ scaling_factor = 1.0
154
+ else:
155
+ scaling_factor = math.sqrt(1 + math.log(scale) / math.log(self.original_max_position_embeddings))
156
+
157
+ cos = emb.cos() * scaling_factor
158
+ sin = emb.sin() * scaling_factor
159
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
160
+
161
+
162
+ class Phi3YarnScaledRotaryEmbedding(Phi3RotaryEmbedding):
163
+ def __init__(self, dim, config, device=None):
164
+ super().__init__(dim, config.max_position_embeddings, config.rope_theta, device)
165
+
166
+ self.short_factor = config.rope_scaling['short_factor']
167
+ self.long_factor = config.rope_scaling['long_factor']
168
+ self.original_max_position_embeddings = config.original_max_position_embeddings
169
+
170
+ @torch.no_grad()
171
+ def forward(self, x, position_ids, seq_len=None):
172
+ seq_len = torch.max(position_ids) + 1
173
+ if seq_len > self.original_max_position_embeddings:
174
+ ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=x.device)
175
+ else:
176
+ ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=x.device)
177
+
178
+ inv_freq_shape = torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim
179
+ self.inv_freq = 1.0 / (ext_factors * self.base**inv_freq_shape)
180
+
181
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
182
+ position_ids_expanded = position_ids[:, None, :].float()
183
+
184
+ # Force float32 since bfloat16 loses precision on long contexts
185
+ # See https://github.com/huggingface/transformers/pull/29285
186
+ device_type = x.device.type
187
+ device_type = device_type if isinstance(device_type, str) and device_type != 'mps' else 'cpu'
188
+ with torch.autocast(device_type=device_type, enabled=False):
189
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
190
+ emb = torch.cat((freqs, freqs), dim=-1)
191
+
192
+ scale = self.max_position_embeddings / self.original_max_position_embeddings
193
+ if scale <= 1.0:
194
+ scaling_factor = 1.0
195
+ else:
196
+ scaling_factor = 0.1 * math.log(scale) + 1.0
197
+
198
+ cos = emb.cos() * scaling_factor
199
+ sin = emb.sin() * scaling_factor
200
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
201
+
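
To make the two long-context scaling rules above concrete, a side-by-side computation of the attention scaling factor with made-up context lengths (4k original, 128k extended):

import math

original_max = 4096       # original_max_position_embeddings (illustrative)
extended_max = 131072     # max_position_embeddings after extension (illustrative)
scale = extended_max / original_max

# 'su' rule (Phi3SuScaledRotaryEmbedding): sqrt(1 + log(scale) / log(original_max))
su_factor = 1.0 if scale <= 1.0 else math.sqrt(1 + math.log(scale) / math.log(original_max))   # ~1.190
# 'yarn' rule (Phi3YarnScaledRotaryEmbedding): 0.1 * log(scale) + 1.0
yarn_factor = 1.0 if scale <= 1.0 else 0.1 * math.log(scale) + 1.0                             # ~1.347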
202
+
203
+ # Copied from transformers.models.llama.modeling_llama.rotate_half
204
+ def rotate_half(x):
205
+ """Rotates half the hidden dims of the input."""
206
+ x1 = x[..., : x.shape[-1] // 2]
207
+ x2 = x[..., x.shape[-1] // 2 :]
208
+ return torch.cat((-x2, x1), dim=-1)
209
+
210
+
211
+ # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
212
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
213
+ """Applies Rotary Position Embedding to the query and key tensors.
214
+
215
+ Args:
216
+ q (`torch.Tensor`): The query tensor.
217
+ k (`torch.Tensor`): The key tensor.
218
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
219
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
220
+ position_ids (`torch.Tensor`, *optional*):
221
+ Deprecated and unused.
222
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
223
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
224
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
225
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
226
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
227
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
228
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
229
+ Returns:
230
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
231
+ """
232
+ cos = cos.unsqueeze(unsqueeze_dim)
233
+ sin = sin.unsqueeze(unsqueeze_dim)
234
+ q_embed = (q * cos) + (rotate_half(q) * sin)
235
+ k_embed = (k * cos) + (rotate_half(k) * sin)
236
+ return q_embed, k_embed
237
+
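
A shape-only sketch of applying the rotary embedding above (toy sizes; _rotate_half duplicates rotate_half so the snippet stands alone):

import torch

def _rotate_half(x):                                  # same operation as rotate_half above
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

bsz, heads, seq, head_dim = 2, 4, 6, 8                # illustrative sizes
q = torch.randn(bsz, heads, seq, head_dim)
k = torch.randn(bsz, heads, seq, head_dim)
cos = torch.randn(bsz, seq, head_dim)                 # as returned by the rotary modules
sin = torch.randn(bsz, seq, head_dim)

# unsqueeze_dim=1 broadcasts cos/sin across the head dimension
q_embed = q * cos.unsqueeze(1) + _rotate_half(q) * sin.unsqueeze(1)
k_embed = k * cos.unsqueeze(1) + _rotate_half(k) * sin.unsqueeze(1)
assert q_embed.shape == q.shape and k_embed.shape == k.shape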
238
+
239
+ class Phi3MLP(nn.Module):
240
+ def __init__(self, config):
241
+ super().__init__()
242
+
243
+ self.config = config
244
+ self.gate_up_proj = nn.Linear(config.hidden_size, 2 * config.intermediate_size, bias=False)
245
+ self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
246
+
247
+ self.activation_fn = ACT2FN[config.hidden_act]
248
+
249
+ def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
250
+ up_states = self.gate_up_proj(hidden_states)
251
+
252
+ gate, up_states = up_states.chunk(2, dim=-1)
253
+ up_states = up_states * self.activation_fn(gate)
254
+
255
+ return self.down_proj(up_states)
256
+
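
A minimal sketch of the fused gate/up projection in Phi3MLP above, with made-up dimensions (SiLU stands in for the configured hidden_act, which is commonly 'silu' for Phi-3):

import torch
import torch.nn as nn
import torch.nn.functional as F

hidden_size, intermediate_size = 16, 32                  # illustrative, not the real config values
gate_up_proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

x = torch.randn(2, 5, hidden_size)
gate, up = gate_up_proj(x).chunk(2, dim=-1)              # one matmul yields both halves
y = down_proj(up * F.silu(gate))                         # gated activation, then projection back
assert y.shape == x.shape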
257
+
258
+ # Copied from transformers.models.llama.modeling_llama.repeat_kv with llama->phi
259
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
260
+ """
261
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
262
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
263
+ """
264
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
265
+ if n_rep == 1:
266
+ return hidden_states
267
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
268
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
269
+
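
The docstring above states that repeat_kv matches torch.repeat_interleave along the head dimension; a quick equivalence check with toy sizes:

import torch

batch, kv_heads, n_rep, seqlen, head_dim = 2, 4, 3, 5, 8   # illustrative sizes
x = torch.randn(batch, kv_heads, seqlen, head_dim)

expanded = x[:, :, None, :, :].expand(batch, kv_heads, n_rep, seqlen, head_dim)
repeated = expanded.reshape(batch, kv_heads * n_rep, seqlen, head_dim)

assert torch.equal(repeated, torch.repeat_interleave(x, repeats=n_rep, dim=1))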
270
+
271
+ class Phi3Attention(nn.Module):
272
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
273
+
274
+ def __init__(self, config: Phi3Config, layer_idx: Optional[int] = None):
275
+ super().__init__()
276
+ self.config = config
277
+ self.layer_idx = layer_idx
278
+ if layer_idx is None:
279
+ logger.warning_once(
280
+ f'Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will '
281
+ 'lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` '
282
+ 'when creating this class.'
283
+ )
284
+
285
+ self.attention_dropout = config.attention_dropout
286
+ self.hidden_size = config.hidden_size
287
+ self.num_heads = config.num_attention_heads
288
+ self.head_dim = self.hidden_size // self.num_heads
289
+ self.num_key_value_heads = config.num_key_value_heads
290
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
291
+ self.max_position_embeddings = config.max_position_embeddings
292
+ self.original_max_position_embeddings = config.original_max_position_embeddings
293
+ self.rope_theta = config.rope_theta
294
+ self.rope_scaling = config.rope_scaling
295
+ self.is_causal = True
296
+
297
+ if (self.head_dim * self.num_heads) != self.hidden_size:
298
+ raise ValueError(
299
+ f'hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}'
300
+ f' and `num_heads`: {self.num_heads}).'
301
+ )
302
+
303
+ op_size = self.num_heads * self.head_dim + 2 * (self.num_key_value_heads * self.head_dim)
304
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
305
+ self.qkv_proj = nn.Linear(self.hidden_size, op_size, bias=False)
306
+ self._init_rope()
307
+
308
+ def _init_rope(self):
309
+ if self.rope_scaling is None:
310
+ self.rotary_emb = Phi3RotaryEmbedding(
311
+ self.head_dim,
312
+ max_position_embeddings=self.max_position_embeddings,
313
+ base=self.rope_theta,
314
+ )
315
+ else:
316
+ scaling_type = self.config.rope_scaling['type']
317
+ if scaling_type == 'su':
318
+ self.rotary_emb = Phi3SuScaledRotaryEmbedding(self.head_dim, self.config)
319
+ elif scaling_type == 'yarn':
320
+ self.rotary_emb = Phi3YarnScaledRotaryEmbedding(self.head_dim, self.config)
321
+ else:
322
+ raise ValueError(f'Unknown RoPE scaling type {scaling_type}')
323
+
324
+ def forward(
325
+ self,
326
+ hidden_states: torch.Tensor,
327
+ attention_mask: Optional[torch.Tensor] = None,
328
+ position_ids: Optional[torch.LongTensor] = None,
329
+ past_key_value: Optional[Cache] = None,
330
+ output_attentions: bool = False,
331
+ use_cache: bool = False,
332
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
333
+ logger.warning_once('You are not running the flash-attention implementation, expect numerical differences.')
334
+
335
+ bsz, q_len, _ = hidden_states.size()
336
+
337
+ qkv = self.qkv_proj(hidden_states)
338
+ query_pos = self.num_heads * self.head_dim
339
+ query_states = qkv[..., :query_pos]
340
+ key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim]
341
+ value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :]
342
+
343
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
344
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
345
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
346
+
347
+ kv_seq_len = key_states.shape[-2]
348
+ if past_key_value is not None:
349
+ if self.layer_idx is None:
350
+ raise ValueError(
351
+ f'The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} '
352
+ 'for auto-regressive decoding with k/v caching, please make sure to initialize the attention class '
353
+ 'with a layer index.'
354
+ )
355
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
356
+ cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len)
357
+
358
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
359
+
360
+ if past_key_value is not None:
361
+ cache_kwargs = {'sin': sin, 'cos': cos} # Specific to RoPE models
362
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
363
+
364
+ # repeat k/v heads if n_kv_heads < n_heads
365
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
366
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
367
+
368
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
369
+
370
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
371
+ raise ValueError(
372
+ f'Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is'
373
+ f' {attn_weights.size()}'
374
+ )
375
+
376
+ if attention_mask is not None:
377
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
378
+ raise ValueError(
379
+ f'Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}'
380
+ )
381
+ attn_weights = attn_weights + attention_mask
382
+
383
+ # upcast attention to fp32
384
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(value_states.dtype)
385
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
386
+
387
+ attn_output = torch.matmul(attn_weights, value_states)
388
+
389
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
390
+ raise ValueError(
391
+ f'`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is'
392
+ f' {attn_output.size()}'
393
+ )
394
+
395
+ attn_output = attn_output.transpose(1, 2).contiguous()
396
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
397
+
398
+ attn_output = self.o_proj(attn_output)
399
+
400
+ if not output_attentions:
401
+ attn_weights = None
402
+
403
+ return attn_output, attn_weights, past_key_value
404
+
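
The score computation in the eager path above is plain scaled dot-product attention with an additive mask; a compact numerical sketch with toy shapes and a causal mask:

import math
import torch
import torch.nn as nn

bsz, heads, q_len, kv_len, head_dim = 1, 2, 4, 4, 8        # illustrative sizes
q = torch.randn(bsz, heads, q_len, head_dim)
k = torch.randn(bsz, heads, kv_len, head_dim)
v = torch.randn(bsz, heads, kv_len, head_dim)

# additive causal mask: 0 on allowed positions, -inf strictly above the diagonal
mask = torch.triu(torch.full((q_len, kv_len), float('-inf')), diagonal=1)

scores = q @ k.transpose(2, 3) / math.sqrt(head_dim) + mask
weights = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(v.dtype)
out = weights @ v                                           # (bsz, heads, q_len, head_dim)
assert out.shape == (bsz, heads, q_len, head_dim)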
405
+
406
+ class Phi3FlashAttention2(Phi3Attention):
407
+ """
408
+ Phi-3 flash attention module. This module inherits from `Phi3Attention` as the weights of the module stays
409
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
410
+ flash attention and deal with padding tokens in case the input contains any of them.
411
+ """
412
+
413
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
414
+ def __init__(self, *args, **kwargs):
415
+ super().__init__(*args, **kwargs)
416
+
417
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
418
+ # flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right alignment, which was made the default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
419
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
420
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
421
+
422
+ def forward(
423
+ self,
424
+ hidden_states: torch.Tensor,
425
+ attention_mask: Optional[torch.LongTensor] = None,
426
+ position_ids: Optional[torch.LongTensor] = None,
427
+ past_key_value: Optional[Cache] = None,
428
+ output_attentions: bool = False,
429
+ use_cache: bool = False,
430
+ prop_index: int = 0,
431
+ **kwargs,
432
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
433
+ # Phi3FlashAttention2 attention does not support output_attentions
434
+
435
+ if not _flash_supports_window_size:
436
+ logger.warning_once(
437
+ "The current flash attention version does not support sliding window attention. Please use `attn_implementation='eager'` or upgrade flash-attn library."
438
+ )
439
+ raise ValueError('The current flash attention version does not support sliding window attention.')
440
+
441
+ output_attentions = False
442
+
443
+ if 'padding_mask' in kwargs:
444
+ warnings.warn(
445
+ 'Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure to use `attention_mask` instead.'
446
+ )
447
+
448
+ # overwrite attention_mask with padding_mask
449
+ attention_mask = kwargs.pop('padding_mask')
450
+
451
+ bsz, q_len, _ = hidden_states.size()
452
+
453
+ qkv = self.qkv_proj(hidden_states)
454
+ query_pos = self.num_heads * self.head_dim
455
+ query_states = qkv[..., :query_pos]
456
+ key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim]
457
+ value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :]
458
+
459
+ # Flash attention requires the input to have the shape
460
+ # batch_size x seq_length x head_dim x hidden_dim
461
+ # therefore we just need to keep the original shape
462
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
463
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
464
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
465
+
466
+ kv_seq_len = key_states.shape[-2]
467
+ if past_key_value is not None and prop_index==0:
468
+ if self.layer_idx is None:
469
+ raise ValueError(
470
+ f'The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} '
471
+ 'for auto-regressive decoding with k/v caching, please make sure to initialize the attention class '
472
+ 'with a layer index.'
473
+ )
474
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
475
+
476
+ # Because the input can be padded, the absolute sequence length depends on the max position id.
477
+ rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
478
+ cos, sin = self.rotary_emb(value_states, position_ids, seq_len=rotary_seq_len)
479
+
480
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
481
+
482
+ use_sliding_windows = (
483
+ _flash_supports_window_size
484
+ and getattr(self.config, 'sliding_window', None) is not None
485
+ and kv_seq_len > self.config.sliding_window
486
+ )
487
+
488
+ if past_key_value is not None and prop_index==0:
489
+ # Activate cache slicing only if the config has a `sliding_window` attribute
490
+ cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
491
+ if (
492
+ getattr(self.config, 'sliding_window', None) is not None
493
+ and kv_seq_len > self.config.sliding_window
494
+ and cache_has_contents
495
+ ):
496
+ slicing_tokens = 1 - self.config.sliding_window
497
+
498
+ past_key = past_key_value[self.layer_idx][0]
499
+ past_value = past_key_value[self.layer_idx][1]
500
+
501
+ past_key = past_key[:, :, slicing_tokens:, :].contiguous()
502
+ past_value = past_value[:, :, slicing_tokens:, :].contiguous()
503
+
504
+ if past_key.shape[-2] != self.config.sliding_window - 1:
505
+ raise ValueError(
506
+ f'past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got'
507
+ f' {past_key.shape}'
508
+ )
509
+
510
+ if attention_mask is not None:
511
+ attention_mask = attention_mask[:, slicing_tokens:]
512
+ attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
513
+
514
+ cache_kwargs = {'sin': sin, 'cos': cos} # Specific to RoPE models
515
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
516
+
517
+ # repeat k/v heads if n_kv_heads < n_heads
518
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
519
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
520
+
521
+ attn_dropout = self.attention_dropout if self.training else 0.0
522
+
523
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
524
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
525
+ # cast them back in the correct dtype just to be sure everything works as expected.
526
+ # This might slowdown training & inference so it is recommended to not cast the LayerNorms
527
+ # in fp32.
528
+
529
+ if query_states.dtype == torch.float32:
530
+ if torch.is_autocast_enabled():
531
+ target_dtype = torch.get_autocast_gpu_dtype()
532
+ # Handle the case where the model is quantized
533
+ elif hasattr(self.config, '_pre_quantization_dtype'):
534
+ target_dtype = self.config._pre_quantization_dtype
535
+ else:
536
+ target_dtype = self.qkv_proj.weight.dtype
537
+
538
+ logger.warning_once(
539
+ f'The input hidden states seem to be silently cast in float32; this might be related to'
540
+ f' the fact you have upcast embedding or layer norm layers in float32. We will cast the input back to'
541
+ f' {target_dtype}.'
542
+ )
543
+
544
+ query_states = query_states.to(target_dtype)
545
+ key_states = key_states.to(target_dtype)
546
+ value_states = value_states.to(target_dtype)
547
+
548
+ # Reshape to the expected shape for Flash Attention
549
+ query_states = query_states.transpose(1, 2)
550
+ key_states = key_states.transpose(1, 2)
551
+ value_states = value_states.transpose(1, 2)
552
+
553
+ attn_output = self._flash_attention_forward(
554
+ query_states,
555
+ key_states,
556
+ value_states,
557
+ attention_mask,
558
+ q_len,
559
+ dropout=attn_dropout,
560
+ use_sliding_windows=use_sliding_windows,
561
+ )
562
+
563
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
564
+ attn_output = self.o_proj(attn_output)
565
+
566
+ if not output_attentions:
567
+ attn_weights = None
568
+
569
+ return attn_output, attn_weights, past_key_value
570
+
571
+ # Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2._flash_attention_forward
572
+ def _flash_attention_forward(
573
+ self,
574
+ query_states,
575
+ key_states,
576
+ value_states,
577
+ attention_mask,
578
+ query_length,
579
+ dropout=0.0,
580
+ softmax_scale=None,
581
+ use_sliding_windows=False,
582
+ ):
583
+ """
584
+ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
585
+ first unpad the input, then compute the attention scores and pad the final attention scores.
586
+
587
+ Args:
588
+ query_states (`torch.Tensor`):
589
+ Input query states to be passed to Flash Attention API
590
+ key_states (`torch.Tensor`):
591
+ Input key states to be passed to Flash Attention API
592
+ value_states (`torch.Tensor`):
593
+ Input value states to be passed to Flash Attention API
594
+ attention_mask (`torch.Tensor`):
595
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
596
+ position of padding tokens and 1 for the position of non-padding tokens.
597
+ dropout (`float`):
598
+ Attention dropout
599
+ softmax_scale (`float`, *optional*):
600
+ The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
601
+ use_sliding_windows (`bool`, *optional*):
602
+ Whether to activate sliding window attention.
603
+ """
604
+ if not self._flash_attn_uses_top_left_mask:
605
+ causal = self.is_causal
606
+ else:
607
+ # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
608
+ causal = self.is_causal and query_length != 1
609
+
610
+ # Contains at least one padding token in the sequence
611
+ if attention_mask is not None:
612
+ batch_size = query_states.shape[0]
613
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
614
+ query_states, key_states, value_states, attention_mask, query_length
615
+ )
616
+
617
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
618
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
619
+
620
+ if not use_sliding_windows:
621
+ attn_output_unpad = flash_attn_varlen_func(
622
+ query_states,
623
+ key_states,
624
+ value_states,
625
+ cu_seqlens_q=cu_seqlens_q,
626
+ cu_seqlens_k=cu_seqlens_k,
627
+ max_seqlen_q=max_seqlen_in_batch_q,
628
+ max_seqlen_k=max_seqlen_in_batch_k,
629
+ dropout_p=dropout,
630
+ softmax_scale=softmax_scale,
631
+ causal=causal,
632
+ )
633
+ else:
634
+ attn_output_unpad = flash_attn_varlen_func(
635
+ query_states,
636
+ key_states,
637
+ value_states,
638
+ cu_seqlens_q=cu_seqlens_q,
639
+ cu_seqlens_k=cu_seqlens_k,
640
+ max_seqlen_q=max_seqlen_in_batch_q,
641
+ max_seqlen_k=max_seqlen_in_batch_k,
642
+ dropout_p=dropout,
643
+ softmax_scale=softmax_scale,
644
+ causal=causal,
645
+ window_size=(self.config.sliding_window, self.config.sliding_window),
646
+ )
647
+
648
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
649
+ else:
650
+ if not use_sliding_windows:
651
+ attn_output = flash_attn_func(
652
+ query_states,
653
+ key_states,
654
+ value_states,
655
+ dropout,
656
+ softmax_scale=softmax_scale,
657
+ causal=causal,
658
+ )
659
+ else:
660
+ attn_output = flash_attn_func(
661
+ query_states,
662
+ key_states,
663
+ value_states,
664
+ dropout,
665
+ softmax_scale=softmax_scale,
666
+ causal=causal,
667
+ window_size=(self.config.sliding_window, self.config.sliding_window),
668
+ )
669
+
670
+ return attn_output
671
+
672
+ # Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2._upad_input
673
+ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
674
+ batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape
675
+
676
+ # On the first iteration we need to properly re-create the padding mask
677
+ # by slicing it at the proper place
678
+ if kv_seq_len != attention_mask.shape[-1]:
679
+ attention_mask_num_tokens = attention_mask.shape[-1]
680
+ attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len :]
681
+
682
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
683
+
684
+ key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
685
+ value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
686
+
687
+ if query_length == kv_seq_len:
688
+ query_layer = index_first_axis(
689
+ query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k
690
+ )
691
+ cu_seqlens_q = cu_seqlens_k
692
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
693
+ indices_q = indices_k
694
+ elif query_length == 1:
695
+ max_seqlen_in_batch_q = 1
696
+ cu_seqlens_q = torch.arange(
697
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
698
+ ) # There is a memcpy here, that is very bad.
699
+ indices_q = cu_seqlens_q[:-1]
700
+ query_layer = query_layer.squeeze(1)
701
+ else:
702
+ # The -q_len: slice assumes left padding.
703
+ attention_mask = attention_mask[:, -query_length:]
704
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
705
+
706
+ return (
707
+ query_layer,
708
+ key_layer,
709
+ value_layer,
710
+ indices_q,
711
+ (cu_seqlens_q, cu_seqlens_k),
712
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
713
+ )
714
+
715
+
716
+ # copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Phi3
717
+ # TODO @Arthur no longer copied from LLama after static cache
718
+ class Phi3SdpaAttention(Phi3Attention):
719
+ """
720
+ Phi3 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
721
+ `Phi3Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
722
+ SDPA API.
723
+ """
724
+
725
+ # Adapted from Phi3Attention.forward
726
+ def forward(
727
+ self,
728
+ hidden_states: torch.Tensor,
729
+ attention_mask: Optional[torch.Tensor] = None,
730
+ position_ids: Optional[torch.LongTensor] = None,
731
+ past_key_value: Optional[Cache] = None,
732
+ output_attentions: bool = False,
733
+ use_cache: bool = False,
734
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
735
+ if output_attentions:
736
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
737
+ logger.warning_once(
738
+ 'Phi3Model is using Phi3SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, '
739
+ 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
740
+ )
741
+ return super().forward(
742
+ hidden_states=hidden_states,
743
+ attention_mask=attention_mask,
744
+ position_ids=position_ids,
745
+ past_key_value=past_key_value,
746
+ output_attentions=output_attentions,
747
+ use_cache=use_cache,
748
+ )
749
+
750
+ bsz, q_len, _ = hidden_states.size()
751
+
752
+ qkv = self.qkv_proj(hidden_states)
753
+ query_pos = self.num_heads * self.head_dim
754
+ query_states = qkv[..., :query_pos]
755
+ key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim]
756
+ value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :]
757
+
758
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
759
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
760
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
761
+
762
+ kv_seq_len = key_states.shape[-2]
763
+ if past_key_value is not None:
764
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
765
+ cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len)
766
+
767
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
768
+
769
+ if past_key_value is not None:
770
+ cache_kwargs = {'sin': sin, 'cos': cos} # Specific to RoPE models
771
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
772
+
773
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
774
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
775
+
776
+ if attention_mask is not None:
777
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
778
+ raise ValueError(
779
+ f'Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}'
780
+ )
781
+
782
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
783
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
784
+ if query_states.device.type == 'cuda' and attention_mask is not None:
785
+ query_states = query_states.contiguous()
786
+ key_states = key_states.contiguous()
787
+ value_states = value_states.contiguous()
788
+
789
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
790
+ query_states,
791
+ key_states,
792
+ value_states,
793
+ attn_mask=attention_mask,
794
+ dropout_p=self.attention_dropout if self.training else 0.0,
795
+ # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
796
+ is_causal=self.is_causal and attention_mask is None and q_len > 1,
797
+ )
798
+
799
+ attn_output = attn_output.transpose(1, 2).contiguous()
800
+ attn_output = attn_output.view(bsz, q_len, self.hidden_size)
801
+
802
+ attn_output = self.o_proj(attn_output)
803
+
804
+ return attn_output, None, past_key_value
805
+
806
+
807
+ PHI3_ATTENTION_CLASSES = {
808
+ 'eager': Phi3Attention,
809
+ 'flash_attention_2': Phi3FlashAttention2,
810
+ 'sdpa': Phi3SdpaAttention,
811
+ }
812
+
813
+
814
+ class Phi3DecoderLayer(nn.Module):
815
+ def __init__(self, config: Phi3Config, layer_idx: int):
816
+ super().__init__()
817
+
818
+ self.config = config
819
+ self.self_attn = Phi3FlashAttention2(config, layer_idx=layer_idx)
820
+
821
+ self.mlp = Phi3MLP(config)
822
+ self.input_layernorm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
823
+
824
+ self.resid_attn_dropout = nn.Dropout(config.resid_pdrop)
825
+ self.resid_mlp_dropout = nn.Dropout(config.resid_pdrop)
826
+ self.post_attention_layernorm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
827
+
828
+ def forward(
829
+ self,
830
+ hidden_states: torch.Tensor,
831
+ attention_mask: Optional[torch.Tensor] = None,
832
+ position_ids: Optional[torch.LongTensor] = None,
833
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
834
+ output_attentions: Optional[bool] = False,
835
+ use_cache: Optional[bool] = False,
836
+ prop_index: Optional[int] = 0,
837
+ **kwargs,
838
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
839
+ if 'padding_mask' in kwargs:
840
+ warnings.warn(
841
+ 'Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure to use `attention_mask` instead.'
842
+ )
843
+ """
844
+ Args:
845
+ hidden_states (`torch.FloatTensor`):
846
+ input to the layer of shape `(batch, seq_len, embed_dim)`
847
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
848
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
849
+ position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
850
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range
851
+ `[0, config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
852
+ output_attentions (`bool`, *optional*):
853
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
854
+ returned tensors for more detail.
855
+ use_cache (`bool`, *optional*):
856
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
857
+ (see `past_key_values`).
858
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
859
+ """
860
+
861
+ residual = hidden_states
862
+
863
+ hidden_states = self.input_layernorm(hidden_states)
864
+
865
+ # Self Attention
866
+ attn_outputs, self_attn_weights, present_key_value = self.self_attn(
867
+ hidden_states=hidden_states,
868
+ attention_mask=attention_mask,
869
+ position_ids=position_ids,
870
+ past_key_value=past_key_value,
871
+ output_attentions=output_attentions,
872
+ use_cache=use_cache,
873
+ prop_index=prop_index
874
+ )
875
+
876
+ hidden_states = residual + self.resid_attn_dropout(attn_outputs)
877
+
878
+ residual = hidden_states
879
+ hidden_states = self.post_attention_layernorm(hidden_states)
880
+ hidden_states = self.mlp(hidden_states)
881
+ hidden_states = residual + self.resid_mlp_dropout(hidden_states)
882
+
883
+ outputs = (hidden_states,)
884
+
885
+ if output_attentions:
886
+ outputs += (self_attn_weights,)
887
+
888
+ if use_cache:
889
+ outputs += (present_key_value,)
890
+
891
+ return outputs
892
+
893
+
894
+ PHI3_START_DOCSTRING = r"""
895
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
896
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
897
+ etc.)
898
+
899
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
900
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
901
+ and behavior.
902
+
903
+ Parameters:
904
+ config ([`Phi3Config`]):
905
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
906
+ load the weights associated with the model, only the configuration. Check out the
907
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
908
+ """
909
+
910
+
911
+ @add_start_docstrings(
912
+ 'The bare Phi-3 model outputting raw hidden-states without any specific head on top.',
913
+ PHI3_START_DOCSTRING,
914
+ )
915
+ class Phi3PreTrainedModel(PreTrainedModel):
916
+ config_class = Phi3Config
917
+ base_model_prefix = 'model'
918
+ supports_gradient_checkpointing = True
919
+ _no_split_modules = ['Phi3DecoderLayer']
920
+ _skip_keys_device_placement = 'past_key_values'
921
+ _supports_flash_attn_2 = True
922
+ _supports_sdpa = False
923
+ _supports_cache_class = True
924
+
925
+ _version = '0.0.5'
926
+
927
+ def _init_weights(self, module):
928
+ std = self.config.initializer_range
929
+ if isinstance(module, nn.Linear):
930
+ module.weight.data.normal_(mean=0.0, std=std)
931
+ if module.bias is not None:
932
+ module.bias.data.zero_()
933
+ elif isinstance(module, nn.Embedding):
934
+ module.weight.data.normal_(mean=0.0, std=std)
935
+ if module.padding_idx is not None:
936
+ module.weight.data[module.padding_idx].zero_()
937
+
938
+
939
+ PHI3_INPUTS_DOCSTRING = r"""
940
+ Args:
941
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
942
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
943
+ it.
944
+
945
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
946
+ [`PreTrainedTokenizer.__call__`] for details.
947
+
948
+ [What are input IDs?](../glossary#input-ids)
949
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
950
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
951
+
952
+ - 1 for tokens that are **not masked**,
953
+ - 0 for tokens that are **masked**.
954
+
955
+ [What are attention masks?](../glossary#attention-mask)
956
+
957
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
958
+ [`PreTrainedTokenizer.__call__`] for details.
959
+
960
+ If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
961
+ `past_key_values`).
962
+
963
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
964
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
965
+ information on the default strategy.
966
+
967
+ - 1 indicates the head is **not masked**,
968
+ - 0 indicates the head is **masked**.
969
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
970
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
971
+ config.n_positions - 1]`.
972
+
973
+ [What are position IDs?](../glossary#position-ids)
974
+ past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
975
+ Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
976
+ blocks) that can be used to speed up sequential decoding. This typically consists of the `past_key_values`
977
+ returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
978
+
979
+ Two formats are allowed:
980
+ - a [`~cache_utils.Cache`] instance;
981
+ - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
982
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
983
+ cache format.
984
+
985
+ The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
986
+ legacy cache format will be returned.
987
+
988
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
989
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
990
+ of shape `(batch_size, sequence_length)`.
991
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
992
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
993
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
994
+ model's internal embedding lookup matrix.
995
+ use_cache (`bool`, *optional*):
996
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
997
+ `past_key_values`).
998
+ output_attentions (`bool`, *optional*):
999
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
1000
+ tensors for more detail.
1001
+ output_hidden_states (`bool`, *optional*):
1002
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
1003
+ more detail.
1004
+ return_dict (`bool`, *optional*):
1005
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1006
+ """
1007
+
1008
+
1009
+ @add_start_docstrings(
1010
+ 'The bare Phi-3 model outputting raw hidden-states without any specific head on top.',
1011
+ PHI3_START_DOCSTRING,
1012
+ )
1013
+ class Phi3Model(Phi3PreTrainedModel):
1014
+ """
1015
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Phi3DecoderLayer`]
1016
+
1017
+ Args:
1018
+ config: Phi3Config
1019
+ """
1020
+
1021
+ def __init__(self, config: Phi3Config):
1022
+ super().__init__(config)
1023
+ self.padding_idx = config.pad_token_id
1024
+ self.vocab_size = config.vocab_size
1025
+
1026
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
1027
+ self.embed_dropout = nn.Dropout(config.embd_pdrop)
1028
+ self.layers = nn.ModuleList(
1029
+ [Phi3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
1030
+ )
1031
+ self._attn_implementation = "flash_attention_2"
1032
+ self.norm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1033
+
1034
+ self.trol_gating = nn.ModuleList([nn.Linear(self.config.hidden_size, 1)]*self.config.num_hidden_layers)
1035
+ self.trol_function = lambda x, idx: 0.5*F.tanh(self.trol_gating[idx](x))+0.5
1036
+
1037
+ self.gradient_checkpointing = False
1038
+ # Initialize weights and apply final processing
1039
+ self.post_init()
1040
+
1041
+ def get_input_embeddings(self):
1042
+ return self.embed_tokens
1043
+
1044
+ def set_input_embeddings(self, value):
1045
+ self.embed_tokens = value
1046
+
1047
+ @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
1048
+ def forward(
1049
+ self,
1050
+ input_ids: torch.LongTensor = None,
1051
+ attention_mask: Optional[torch.Tensor] = None,
1052
+ position_ids: Optional[torch.LongTensor] = None,
1053
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1054
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1055
+ use_cache: Optional[bool] = None,
1056
+ output_attentions: Optional[bool] = None,
1057
+ output_hidden_states: Optional[bool] = None,
1058
+ return_dict: Optional[bool] = None,
1059
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
1060
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1061
+ output_hidden_states = (
1062
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1063
+ )
1064
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1065
+
1066
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1067
+
1068
+ # retrieve input_ids and inputs_embeds
1069
+ if input_ids is not None and inputs_embeds is not None:
1070
+ raise ValueError('You cannot specify both input_ids and inputs_embeds at the same time')
1071
+ elif input_ids is not None:
1072
+ batch_size, seq_length = input_ids.shape[:2]
1073
+ elif inputs_embeds is not None:
1074
+ batch_size, seq_length = inputs_embeds.shape[:2]
1075
+ else:
1076
+ raise ValueError('You have to specify either input_ids or inputs_embeds')
1077
+
1078
+ past_key_values_length = 0
1079
+
1080
+ if self.gradient_checkpointing and self.training:
1081
+ if use_cache:
1082
+ logger.warning_once(
1083
+ '`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...'
1084
+ )
1085
+ use_cache = False
1086
+
1087
+ if use_cache:
1088
+ use_legacy_cache = not isinstance(past_key_values, Cache)
1089
+ if use_legacy_cache:
1090
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
1091
+ past_key_values_length = past_key_values.get_usable_length(seq_length)
1092
+
1093
+ if position_ids is None:
1094
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
1095
+ position_ids = torch.arange(
1096
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
1097
+ )
1098
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
1099
+ else:
1100
+ position_ids = position_ids.view(-1, seq_length).long()
1101
+
1102
+ if inputs_embeds is None:
1103
+ inputs_embeds = self.embed_tokens(input_ids)
1104
+
1105
+ if attention_mask is not None and self._attn_implementation == 'flash_attention_2' and use_cache:
1106
+ is_padding_right = attention_mask[:, -1].sum().item() != batch_size
1107
+ if is_padding_right:
1108
+ raise ValueError(
1109
+ "You are attempting to perform batched generation with padding_side='right'"
1110
+ ' this may lead to unexpected behaviour for Flash Attention version of Phi3. Make sure to '
1111
+ " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
1112
+ )
1113
+
1114
+ if self._attn_implementation == 'flash_attention_2':
1115
+ # 2d mask is passed through the layers
1116
+ attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
1117
+ else:
1118
+ # 4d mask is passed through the layers
1119
+ attention_mask = _prepare_4d_causal_attention_mask(
1120
+ attention_mask,
1121
+ (batch_size, seq_length),
1122
+ inputs_embeds,
1123
+ past_key_values_length,
1124
+ sliding_window=self.config.sliding_window,
1125
+ )
1126
+
1127
+ hidden_states = inputs_embeds
1128
+
1129
+ # decoder layers
1130
+ all_hidden_states = () if output_hidden_states else None
1131
+ all_self_attns = () if output_attentions else None
1132
+ next_decoder_cache = None
1133
+
1134
+ for idx, decoder_layer in enumerate(self.layers):
1135
+ if output_hidden_states:
1136
+ all_hidden_states += (hidden_states,)
1137
+
1138
+ if self.gradient_checkpointing and self.training:
1139
+ # TroL reusing
1140
+ original_hidden_states_list = []
1141
+ for _ in range(2):
1142
+ layer_outputs = self._gradient_checkpointing_func(
1143
+ decoder_layer.__call__,
1144
+ hidden_states,
1145
+ attention_mask,
1146
+ position_ids,
1147
+ past_key_values,
1148
+ output_attentions,
1149
+ use_cache,
1150
+ )
1151
+ hidden_states = layer_outputs[0]
1152
+ original_hidden_states_list.append(layer_outputs[0])
1153
+ # Second TroL Gating & Feature Merging
1154
+ trol_score = self.trol_function(original_hidden_states_list[0], idx)
1155
+ updated_hidden_states = original_hidden_states_list[0] * (1 - trol_score) + original_hidden_states_list[1] * trol_score
1156
+
1157
+ else:
1158
+ if hidden_states.shape[1] > 1:
1159
+ # TroL reusing
1160
+ original_hidden_states_list = []
1161
+ for idxidx in range(2):
1162
+ layer_outputs = decoder_layer(
1163
+ hidden_states,
1164
+ attention_mask=attention_mask,
1165
+ position_ids=position_ids,
1166
+ past_key_value=past_key_values,
1167
+ output_attentions=output_attentions,
1168
+ use_cache=use_cache,
1169
+ prop_index=idxidx
1170
+ )
1171
+ hidden_states = layer_outputs[0]
1172
+ original_hidden_states_list.append(layer_outputs[0])
1173
+ # Second TroL Gating & Feature Merging
1174
+ trol_score = self.trol_function(original_hidden_states_list[0], idx)
1175
+ updated_hidden_states = original_hidden_states_list[0] * (1-trol_score) + original_hidden_states_list[1] * trol_score
1176
+ else:
1177
+ # TroL reusing
1178
+ layer_outputs = decoder_layer(
1179
+ hidden_states,
1180
+ attention_mask=attention_mask,
1181
+ position_ids=position_ids,
1182
+ past_key_value=past_key_values,
1183
+ output_attentions=output_attentions,
1184
+ use_cache=use_cache,
1185
+ )
1186
+ updated_hidden_states = layer_outputs[0]
1187
+
1188
+ # hidden_states = layer_outputs[0]
1189
+ hidden_states = updated_hidden_states
1190
+
1191
+ if use_cache:
1192
+ next_decoder_cache = layer_outputs[1]
1193
+
1194
+ if output_attentions:
1195
+ all_self_attns += (layer_outputs[1],)
1196
+
1197
+ hidden_states = self.norm(hidden_states)
1198
+
1199
+ # add hidden states from the last decoder layer
1200
+ if output_hidden_states:
1201
+ all_hidden_states += (hidden_states,)
1202
+
1203
+ next_cache = None
1204
+ if use_cache:
1205
+ next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
1206
+ if not return_dict:
1207
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
1208
+ return BaseModelOutputWithPast(
1209
+ last_hidden_state=hidden_states,
1210
+ past_key_values=next_cache,
1211
+ hidden_states=all_hidden_states,
1212
+ attentions=all_self_attns,
1213
+ )
1214
+
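
The layer-reuse ('TroL') branch in the loop above runs each decoder layer twice and blends the two passes with a learned gate; a standalone sketch of just that gating arithmetic, with toy sizes (torch.tanh is used here, equivalent to the F.tanh call in trol_function):

import torch
import torch.nn as nn

hidden_size = 16                                   # illustrative
gate = nn.Linear(hidden_size, 1)                   # stand-in for one trol_gating entry

h_first = torch.randn(2, 5, hidden_size)           # hidden states after the first pass of a layer
h_second = torch.randn(2, 5, hidden_size)          # hidden states after reusing the same layer

score = 0.5 * torch.tanh(gate(h_first)) + 0.5      # in (0, 1); shape (2, 5, 1), broadcasts over hidden
merged = h_first * (1 - score) + h_second * score  # feature merging of the two traversals
assert merged.shape == h_first.shape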
1215
+
1216
+ class Phi3ForCausalLM(Phi3PreTrainedModel):
1217
+ _tied_weights_keys = ['lm_head.weight']
1218
+
1219
+ def __init__(self, config):
1220
+ super().__init__(config)
1221
+ self.model = Phi3Model(config)
1222
+ self.vocab_size = config.vocab_size
1223
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1224
+
1225
+ # Initialize weights and apply final processing
1226
+ self.post_init()
1227
+
1228
+ def get_input_embeddings(self):
1229
+ return self.model.embed_tokens
1230
+
1231
+ def set_input_embeddings(self, value):
1232
+ self.model.embed_tokens = value
1233
+
1234
+ def get_output_embeddings(self):
1235
+ return self.lm_head
1236
+
1237
+ def set_output_embeddings(self, new_embeddings):
1238
+ self.lm_head = new_embeddings
1239
+
1240
+ def set_decoder(self, decoder):
1241
+ self.model = decoder
1242
+
1243
+ def get_decoder(self):
1244
+ return self.model
1245
+
1246
+ # Ignore copy
1247
+ @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
1248
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1249
+ def forward(
1250
+ self,
1251
+ input_ids: torch.LongTensor = None,
1252
+ attention_mask: Optional[torch.Tensor] = None,
1253
+ position_ids: Optional[torch.LongTensor] = None,
1254
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1255
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1256
+ labels: Optional[torch.LongTensor] = None,
1257
+ use_cache: Optional[bool] = None,
1258
+ output_attentions: Optional[bool] = None,
1259
+ output_hidden_states: Optional[bool] = None,
1260
+ return_dict: Optional[bool] = None,
1261
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
1262
+ r"""
1263
+ Args:
1264
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1265
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1266
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1267
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1268
+
1269
+ Returns:
1270
+
1271
+ Example:
1272
+
1273
+ ```python
1274
+ >>> from transformers import AutoTokenizer, Phi3ForCausalLM
1275
+
1276
+ >>> model = Phi3ForCausalLM.from_pretrained("microsoft/phi-3-mini-4k-instruct")
1277
+ >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-3-mini-4k-instruct")
1278
+
1279
+ >>> prompt = "This is an example script ."
1280
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
1281
+
1282
+ >>> # Generate
1283
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1284
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1285
+ 'This is an example script .\n Certainly! Below is a sample script that demonstrates a simple task, such as calculating the sum'
1286
+ ```"""
1287
+
1288
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1289
+ output_hidden_states = (
1290
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1291
+ )
1292
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1293
+
1294
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1295
+ outputs = self.model(
1296
+ input_ids=input_ids,
1297
+ attention_mask=attention_mask,
1298
+ position_ids=position_ids,
1299
+ past_key_values=past_key_values,
1300
+ inputs_embeds=inputs_embeds,
1301
+ use_cache=use_cache,
1302
+ output_attentions=output_attentions,
1303
+ output_hidden_states=output_hidden_states,
1304
+ return_dict=return_dict,
1305
+ )
1306
+
1307
+ hidden_states = outputs[0]
1308
+ logits = self.lm_head(hidden_states)
1309
+ logits = logits.float()
1310
+
1311
+ loss = None
1312
+ if labels is not None:
1313
+ # Shift so that tokens < n predict n
1314
+ shift_logits = logits[..., :-1, :].contiguous()
1315
+ shift_labels = labels[..., 1:].contiguous()
1316
+ # Flatten the tokens
1317
+ loss_fct = CrossEntropyLoss()
1318
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
1319
+ shift_labels = shift_labels.view(-1)
1320
+ # Enable model parallelism
1321
+ shift_labels = shift_labels.to(shift_logits.device)
1322
+ loss = loss_fct(shift_logits, shift_labels)
1323
+
1324
+ if not return_dict:
1325
+ output = (logits,) + outputs[1:]
1326
+ return (loss,) + output if loss is not None else output
1327
+
1328
+ return CausalLMOutputWithPast(
1329
+ loss=loss,
1330
+ logits=logits,
1331
+ past_key_values=outputs.past_key_values,
1332
+ hidden_states=outputs.hidden_states,
1333
+ attentions=outputs.attentions,
1334
+ )
1335
+
1336
+ # Copied from transformers.models.persimmon.modeling_persimmon.PersimmonForCausalLM.prepare_inputs_for_generation
1337
+ def prepare_inputs_for_generation(
1338
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
1339
+ ):
1340
+ if past_key_values is not None:
1341
+ if isinstance(past_key_values, Cache):
1342
+ cache_length = past_key_values.get_seq_length()
1343
+ past_length = past_key_values.seen_tokens
1344
+ max_cache_length = past_key_values.get_max_length()
1345
+ else:
1346
+ cache_length = past_length = past_key_values[0][0].shape[2]
1347
+ max_cache_length = None
1348
+
1349
+ # Keep only the unprocessed tokens:
1350
+ # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
1351
+ # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
1352
+ # input)
1353
+ if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
1354
+ input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
1355
+ # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
1356
+ # input_ids based on the past_length.
1357
+ elif past_length < input_ids.shape[1]:
1358
+ input_ids = input_ids[:, past_length:]
1359
+ # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
1360
+
1361
+ # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
1362
+ if (
1363
+ max_cache_length is not None
1364
+ and attention_mask is not None
1365
+ and cache_length + input_ids.shape[1] > max_cache_length
1366
+ ):
1367
+ attention_mask = attention_mask[:, -max_cache_length:]
1368
+
1369
+ position_ids = kwargs.get('position_ids', None)
1370
+ if attention_mask is not None and position_ids is None:
1371
+ # create position_ids on the fly for batch generation
1372
+ position_ids = attention_mask.long().cumsum(-1) - 1
1373
+ position_ids.masked_fill_(attention_mask == 0, 1)
1374
+ if past_key_values:
1375
+ position_ids = position_ids[:, -input_ids.shape[1] :]
1376
+
1377
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1378
+ if inputs_embeds is not None and past_key_values is None:
1379
+ model_inputs = {'inputs_embeds': inputs_embeds}
1380
+ else:
1381
+ model_inputs = {'input_ids': input_ids}
1382
+
1383
+ model_inputs.update(
1384
+ {
1385
+ 'position_ids': position_ids,
1386
+ 'past_key_values': past_key_values,
1387
+ 'use_cache': kwargs.get('use_cache'),
1388
+ 'attention_mask': attention_mask,
1389
+ }
1390
+ )
1391
+ return model_inputs
1392
+
1393
+ @staticmethod
1394
+ def _reorder_cache(past_key_values, beam_idx):
1395
+ reordered_past = ()
1396
+ for layer_past in past_key_values:
1397
+ reordered_past += (
1398
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
1399
+ )
1400
+ return reordered_past
1401
+
1402
+
1403
+ @add_start_docstrings(
1404
+ """
1405
+ The [`Phi3Model`] with a sequence classification head on top (linear layer).
1406
+
1407
+ [`Phi3ForSequenceClassification`] uses the last token in order to do the classification, as other causal models
1408
+ (e.g. GPT-2) do.
1409
+
1410
+ Since it does classification on the last token, it requires to know the position of the last token. If a
1411
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
1412
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
1413
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
1414
+ each row of the batch).
1415
+ """,
1416
+ PHI3_START_DOCSTRING,
1417
+ )
1418
+ # Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Phi3, LLAMA->PHI3, self.transformer->self.model, transformer_outputs->model_outputs
1419
+ class Phi3ForSequenceClassification(Phi3PreTrainedModel):
1420
+ def __init__(self, config):
1421
+ super().__init__(config)
1422
+ self.num_labels = config.num_labels
1423
+ self.model = Phi3Model(config)
1424
+ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
1425
+
1426
+ # Initialize weights and apply final processing
1427
+ self.post_init()
1428
+
1429
+ def get_input_embeddings(self):
1430
+ return self.model.embed_tokens
1431
+
1432
+ def set_input_embeddings(self, value):
1433
+ self.model.embed_tokens = value
1434
+
1435
+ @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
1436
+ def forward(
1437
+ self,
1438
+ input_ids: torch.LongTensor = None,
1439
+ attention_mask: Optional[torch.Tensor] = None,
1440
+ position_ids: Optional[torch.LongTensor] = None,
1441
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1442
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1443
+ labels: Optional[torch.LongTensor] = None,
1444
+ use_cache: Optional[bool] = None,
1445
+ output_attentions: Optional[bool] = None,
1446
+ output_hidden_states: Optional[bool] = None,
1447
+ return_dict: Optional[bool] = None,
1448
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1449
+ r"""
1450
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1451
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1452
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1453
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1454
+ """
1455
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1456
+
1457
+ model_outputs = self.model(
1458
+ input_ids,
1459
+ attention_mask=attention_mask,
1460
+ position_ids=position_ids,
1461
+ past_key_values=past_key_values,
1462
+ inputs_embeds=inputs_embeds,
1463
+ use_cache=use_cache,
1464
+ output_attentions=output_attentions,
1465
+ output_hidden_states=output_hidden_states,
1466
+ return_dict=return_dict,
1467
+ )
1468
+ hidden_states = model_outputs[0]
1469
+ logits = self.score(hidden_states)
1470
+
1471
+ if input_ids is not None:
1472
+ batch_size = input_ids.shape[0]
1473
+ else:
1474
+ batch_size = inputs_embeds.shape[0]
1475
+
1476
+ if self.config.pad_token_id is None and batch_size != 1:
1477
+ raise ValueError('Cannot handle batch sizes > 1 if no padding token is defined.')
1478
+ if self.config.pad_token_id is None:
1479
+ sequence_lengths = -1
1480
+ else:
1481
+ if input_ids is not None:
1482
+ # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
1483
+ sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
1484
+ sequence_lengths = sequence_lengths % input_ids.shape[-1]
1485
+ sequence_lengths = sequence_lengths.to(logits.device)
1486
+ else:
1487
+ sequence_lengths = -1
1488
+
1489
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
1490
+
1491
+ loss = None
1492
+ if labels is not None:
1493
+ labels = labels.to(logits.device)
1494
+ if self.config.problem_type is None:
1495
+ if self.num_labels == 1:
1496
+ self.config.problem_type = 'regression'
1497
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1498
+ self.config.problem_type = 'single_label_classification'
1499
+ else:
1500
+ self.config.problem_type = 'multi_label_classification'
1501
+
1502
+ if self.config.problem_type == 'regression':
1503
+ loss_fct = MSELoss()
1504
+ if self.num_labels == 1:
1505
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
1506
+ else:
1507
+ loss = loss_fct(pooled_logits, labels)
1508
+ elif self.config.problem_type == 'single_label_classification':
1509
+ loss_fct = CrossEntropyLoss()
1510
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
1511
+ elif self.config.problem_type == 'multi_label_classification':
1512
+ loss_fct = BCEWithLogitsLoss()
1513
+ loss = loss_fct(pooled_logits, labels)
1514
+ if not return_dict:
1515
+ output = (pooled_logits,) + model_outputs[1:]
1516
+ return ((loss,) + output) if loss is not None else output
1517
+
1518
+ return SequenceClassifierOutputWithPast(
1519
+ loss=loss,
1520
+ logits=pooled_logits,
1521
+ past_key_values=model_outputs.past_key_values,
1522
+ hidden_states=model_outputs.hidden_states,
1523
+ attentions=model_outputs.attentions,
1524
+ )
1525
+
1526
+
1527
+ @add_start_docstrings(
1528
+ """
1529
+ [`Phi3Model`] with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
1530
+ Named-Entity-Recognition (NER) tasks.
1531
+ """,
1532
+ PHI3_START_DOCSTRING,
1533
+ )
1534
+ # Copied from transformers.models.mpt.modeling_mpt.MptForTokenClassification with Mpt->Phi3,MPT->PHI3,self.transformer->self.model,transformer_outputs->model_outputs
1535
+ class Phi3ForTokenClassification(Phi3PreTrainedModel):
1536
+ def __init__(self, config: Phi3Config):
1537
+ super().__init__(config)
1538
+ self.num_labels = config.num_labels
1539
+
1540
+ self.model = Phi3Model(config)
1541
+ if hasattr(config, 'classifier_dropout') and config.classifier_dropout is not None:
1542
+ classifier_dropout = config.classifier_dropout
1543
+ elif hasattr(config, 'hidden_dropout') and config.hidden_dropout is not None:
1544
+ classifier_dropout = config.hidden_dropout
1545
+ else:
1546
+ classifier_dropout = 0.1
1547
+ self.dropout = nn.Dropout(classifier_dropout)
1548
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1549
+
1550
+ # Initialize weights and apply final processing
1551
+ self.post_init()
1552
+
1553
+ @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
1554
+ @add_code_sample_docstrings(
1555
+ checkpoint=_CHECKPOINT_FOR_DOC,
1556
+ output_type=TokenClassifierOutput,
1557
+ config_class=_CONFIG_FOR_DOC,
1558
+ )
1559
+ def forward(
1560
+ self,
1561
+ input_ids: Optional[torch.LongTensor] = None,
1562
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
1563
+ attention_mask: Optional[torch.Tensor] = None,
1564
+ inputs_embeds: Optional[torch.Tensor] = None,
1565
+ labels: Optional[torch.Tensor] = None,
1566
+ use_cache: Optional[bool] = None,
1567
+ output_attentions: Optional[bool] = None,
1568
+ output_hidden_states: Optional[bool] = None,
1569
+ return_dict: Optional[bool] = None,
1570
+ **deprecated_arguments,
1571
+ ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
1572
+ r"""
1573
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1574
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1575
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1576
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1577
+ """
1578
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1579
+
1580
+ model_outputs = self.model(
1581
+ input_ids,
1582
+ past_key_values=past_key_values,
1583
+ attention_mask=attention_mask,
1584
+ inputs_embeds=inputs_embeds,
1585
+ use_cache=use_cache,
1586
+ output_attentions=output_attentions,
1587
+ output_hidden_states=output_hidden_states,
1588
+ return_dict=return_dict,
1589
+ )
1590
+
1591
+ hidden_states = model_outputs[0]
1592
+ hidden_states = self.dropout(hidden_states)
1593
+ logits = self.classifier(hidden_states)
1594
+
1595
+ loss = None
1596
+ if labels is not None:
1597
+ # move labels to correct device to enable model parallelism
1598
+ labels = labels.to(logits.device)
1599
+ batch_size, seq_length = labels.shape
1600
+ loss_fct = CrossEntropyLoss()
1601
+ loss = loss_fct(
1602
+ logits.view(batch_size * seq_length, self.num_labels), labels.view(batch_size * seq_length)
1603
+ )
1604
+
1605
+ if not return_dict:
1606
+ output = (logits,) + model_outputs[2:]
1607
+ return ((loss,) + output) if loss is not None else output
1608
+
1609
+ return TokenClassifierOutput(
1610
+ loss=loss,
1611
+ logits=logits,
1612
+ hidden_states=model_outputs.hidden_states,
1613
+ attentions=model_outputs.attentions,
1614
+ )
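For orientation, the TroL-specific step in `Phi3Model.forward` above is the second-pass gating: the hidden states from the first traversal of a layer and from its reuse are blended with a learned per-token score. Below is a minimal standalone sketch of that blend; the tensor sizes and the sigmoid-linear gate are hypothetical stand-ins for the model's `trol_function`, not the repository's exact module.

```python
import torch
import torch.nn as nn

# Toy illustration of the gating above:
# h0 = layer output from the first pass, h1 = output from reusing the same layer.
batch, seq, hidden = 1, 4, 8              # hypothetical sizes
h0 = torch.randn(batch, seq, hidden)
h1 = torch.randn(batch, seq, hidden)

gate = torch.sigmoid(nn.Linear(hidden, 1)(h0))   # stand-in for self.trol_function -> (1, 4, 1)
updated = h0 * (1 - gate) + h1 * gate            # same convex combination as the model code
print(updated.shape)                             # torch.Size([1, 4, 8])
```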
trol/arch_phi3/modeling_trol.py ADDED
@@ -0,0 +1,346 @@
1
+ # System
2
+ import torch
3
+ from torch import nn
4
+ from utils.utils import *
5
+ import torch.utils.checkpoint
6
+ from typing import List, Optional, Tuple, Union
7
+ from transformers.modeling_utils import PreTrainedModel
8
+
9
+ # trol file
10
+ from .modeling_intern_vit import InternVisionModel
11
+ from .modeling_phi3 import Phi3ForCausalLM
12
+
13
+ # Dataclass & ModelOutput
14
+ from dataclasses import dataclass
15
+ from transformers.modeling_outputs import ModelOutput
16
+
17
+
18
+ # Configuration
19
+ ########################################################################################
20
+ import copy
21
+ from transformers.configuration_utils import PretrainedConfig
22
+ from .configuration_intern_vit import InternVisionConfig
23
+ from .configuration_phi3 import Phi3Config
24
+
25
+ class TroLConfig(PretrainedConfig):
26
+ model_type = 'trol'
27
+ is_composition = True
28
+
29
+ def __init__(
30
+ self,
31
+ vision_config=None,
32
+ llm_config=None,
33
+ use_backbone_lora=0,
34
+ use_llm_lora=0,
35
+ pad2square=False,
36
+ select_layer=-1,
37
+ force_image_size=None,
38
+ downsample_ratio=0.5,
39
+ template=None,
40
+ dynamic_image_size=False,
41
+ use_thumbnail=False,
42
+ ps_version='v1',
43
+ min_dynamic_patch=1,
44
+ max_dynamic_patch=6,
45
+ **kwargs):
46
+ super().__init__(**kwargs)
47
+ self.vision_config = InternVisionConfig(**(vision_config or {}))
48
+ self.llm_config = Phi3Config(**(llm_config or {}))
49
+ self.use_backbone_lora = use_backbone_lora
50
+ self.use_llm_lora = use_llm_lora
51
+ self.pad2square = pad2square
52
+ self.select_layer = select_layer
53
+ self.force_image_size = force_image_size
54
+ self.downsample_ratio = downsample_ratio
55
+ self.template = template
56
+ self.dynamic_image_size = dynamic_image_size
57
+ self.use_thumbnail = use_thumbnail
58
+ self.ps_version = ps_version # pixel shuffle version
59
+ self.min_dynamic_patch = min_dynamic_patch
60
+ self.max_dynamic_patch = max_dynamic_patch
61
+
62
+ def to_dict(self):
63
+ output = copy.deepcopy(self.__dict__)
64
+ output['vision_config'] = self.vision_config.to_dict()
65
+ output['llm_config'] = self.llm_config.to_dict()
66
+ output['model_type'] = self.__class__.model_type
67
+ output['use_backbone_lora'] = self.use_backbone_lora
68
+ output['use_llm_lora'] = self.use_llm_lora
69
+ output['pad2square'] = self.pad2square
70
+ output['select_layer'] = self.select_layer
71
+ output['force_image_size'] = self.force_image_size
72
+ output['downsample_ratio'] = self.downsample_ratio
73
+ output['template'] = self.template
74
+ output['dynamic_image_size'] = self.dynamic_image_size
75
+ output['use_thumbnail'] = self.use_thumbnail
76
+ output['ps_version'] = self.ps_version
77
+ output['min_dynamic_patch'] = self.min_dynamic_patch
78
+ output['max_dynamic_patch'] = self.max_dynamic_patch
79
+ return output
80
+ ########################################################################################
81
+
82
+ @dataclass
83
+ class TroLCausalLMOutputWithPast(ModelOutput):
84
+ loss: Optional[torch.FloatTensor] = None
85
+ logits: torch.FloatTensor = None
86
+ past_key_values: Optional[List[torch.FloatTensor]] = None
87
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
88
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
89
+ image_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
90
+
91
+ class TroLForCausalLM(PreTrainedModel):
92
+ config_class = TroLConfig
93
+
94
+ def __init__(self, config):
95
+ super().__init__(config)
96
+
97
+ image_size = config.force_image_size or config.vision_config.image_size
98
+ patch_size = config.vision_config.patch_size
99
+ self.patch_size = patch_size
100
+ self.select_layer = config.select_layer
101
+ self.template = config.template
102
+ self.num_image_token = int((image_size // patch_size) ** 2 * (config.downsample_ratio ** 2))
103
+ self.downsample_ratio = config.downsample_ratio
104
+ self.ps_version = config.ps_version
105
+
106
+ self.vision_model = InternVisionModel(config.vision_config)
107
+ self.language_model = Phi3ForCausalLM(config.llm_config)
108
+ self.prompt_rule = {"system_start": "<s><|system|>\n",
109
+ "system_end": "<|end|>",
110
+ "user_start": "<|user|>\n",
111
+ "user_end": "<|end|>",
112
+ "assistant_start": "<|assistant|>\n",
113
+ "assistant_end": "<|end|>\n</s>",
114
+ "test_start": "<|assistant|>\n",
115
+ "test_end": "<|end|>",
116
+ "split": "\n",
117
+ }
118
+
119
+ vit_hidden_size = config.vision_config.hidden_size
120
+ llm_hidden_size = config.llm_config.hidden_size
121
+
122
+ self.vision_proj = nn.Sequential(
123
+ nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio) ** 2),
124
+ nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio) ** 2, llm_hidden_size),
125
+ nn.GELU(),
126
+ nn.Linear(llm_hidden_size, llm_hidden_size)
127
+ )
128
+
129
+ def extract_feature(self, pixel_values):
130
+ self.vision_model.eval()
131
+ vit_embeds = self.vision_model(
132
+ pixel_values=pixel_values,
133
+ output_hidden_states=False,
134
+ return_dict=True).last_hidden_state
135
+ vit_embeds = vit_embeds[:, 1:, :]
136
+ h = w = int(vit_embeds.shape[1] ** 0.5)
137
+ vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
138
+ vit_embeds = pixel_shuffle(vit_embeds, scale_factor=self.downsample_ratio)
139
+ vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1])
140
+ return vit_embeds
141
+
142
+ def eval_process(
143
+ self,
144
+ inputs,
145
+ data,
146
+ tokenizer,
147
+ device,
148
+ img_token_number,
149
+ ):
150
+ batched_image = []
151
+ batched_qa_prompt=[]
152
+ for _input in inputs:
153
+
154
+ # Visualization
155
+ # imim = _input['image'].cpu().permute(1, 2, 0)
156
+
157
+ # add <image> to the question when an image is provided but the tag is missing; the system prompt and image special tokens are attached below
158
+ if 'image' in _input.keys() and '<image>' not in _input['question']: _input['question'] = '<image>\n' + _input['question']
159
+
160
+ # making image prompt
161
+ if 'image' in _input.keys() and _input['image'] is not None:
162
+ process_image = dynamic_preprocess(_input['image'].to(device))
163
+ dynamic_process_image = torch.stack([dynamic_transform(image) for image in process_image]).to(device)
164
+ img_token_number = dynamic_process_image.shape[0] * 256
165
+ batched_image.append(dynamic_process_image)
166
+
167
+ # make question and answer
168
+ question = make_instruction(_input['question'], data, self.prompt_rule)
169
+
170
+ # adding image special tokens to question
171
+ if 'image' in _input.keys(): question = question.replace('<image>', '<img><IMG_CONTEXT></img>')
172
+
173
+ # add bundle image tokens if it has <IMG_CONTEXT> token
174
+ question = add_bundle_tokens(question, '<IMG_CONTEXT>', img_token_number)
175
+
176
+ batched_qa_prompt.append(question)
177
+
178
+ '''For Final Outputs'''
179
+ qa_prompts = tokenizer(batched_qa_prompt, padding='longest', return_tensors="pt", add_special_tokens=False)
180
+
181
+ # [1] input_ids
182
+ input_ids = qa_prompts.input_ids.to(device)
183
+
184
+ # [2] attention_mask
185
+ attention_mask = qa_prompts.attention_mask.to(device)
186
+
187
+ if len(batched_image):
188
+ return {"input_ids": input_ids,
189
+ "attention_mask": attention_mask,
190
+ "image_features": self.extract_feature(torch.cat(batched_image, dim=0).to(device))
191
+ }
192
+ else:
193
+ return {"input_ids": input_ids,
194
+ "attention_mask": attention_mask,
195
+ }
196
+
197
+ def _merge_input_embeds_with_image_features(self, image_features, inputs_embeds, input_ids):
198
+ B, N, C = inputs_embeds.shape
199
+ input_ids = input_ids.reshape(B * N)
200
+ inputs_embeds = inputs_embeds.reshape(B * N, C)
201
+ selected = torch.where(input_ids == self.config.image_token_index)
202
+ assert selected[0].numel() != 0
203
+ inputs_embeds[selected] = image_features.reshape(-1, C).to(inputs_embeds.device)
204
+ inputs_embeds = inputs_embeds.reshape(B, N, C)
205
+ return inputs_embeds
206
+
207
+ def forward(
208
+ self,
209
+ input_ids: torch.LongTensor = None,
210
+ image_features: torch.FloatTensor = None,
211
+ attention_mask: Optional[torch.Tensor] = None,
212
+ position_ids: Optional[torch.LongTensor] = None,
213
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
214
+ inputs_embeds: Optional[torch.FloatTensor] = None,
215
+ labels: Optional[torch.LongTensor] = None,
216
+ use_cache: Optional[bool] = None,
217
+ output_attentions: Optional[bool] = None,
218
+ output_hidden_states: Optional[bool] = None,
219
+ return_dict: Optional[bool] = None,
220
+ ) -> Union[Tuple, TroLCausalLMOutputWithPast]:
221
+
222
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
223
+ output_hidden_states = (
224
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
225
+ )
226
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
227
+
228
+ if inputs_embeds is None:
229
+ # 1. Extra the input embeddings
230
+ try:
231
+ inputs_embeds = self.language_model.get_input_embeddings()(input_ids).requires_grad_(False)
232
+ except:
233
+ inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
234
+
235
+ # 2. Merge text and images
236
+ if image_features is not None and input_ids.shape[1] != 1:
237
+
238
+ image_features = self.vision_proj(image_features.to(inputs_embeds.dtype))
239
+ inputs_embeds = self._merge_input_embeds_with_image_features(image_features, inputs_embeds, input_ids)
240
+
241
+ # In case input_ids.shape[1] == 1 & image_features==None & past_key_values != None, we are in the case of
242
+ # generation with cache
243
+ elif past_key_values is not None and image_features is not None and input_ids.shape[1] == 1:
244
+ # Retrieve the first layer to inspect the logits and mask out the hidden states
245
+ # that are set to 0
246
+ first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
247
+
248
+ # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
249
+ batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
250
+
251
+ # Get the target length
252
+ target_length = input_ids.shape[1]
253
+ past_length = first_layer_past_key_value.shape[-1]
254
+
255
+ extended_attention_mask = torch.ones(
256
+ (attention_mask.shape[0], past_length),
257
+ dtype=attention_mask.dtype,
258
+ device=attention_mask.device,
259
+ )
260
+
261
+ # Filter out only the tokens that can be un-attended, this can happen
262
+ # if one uses Llava + Fused modules where the cache on the
263
+ # first iteration is already big enough, or if one passes custom cache
264
+ valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
265
+ new_batch_index = batch_index[valid_indices]
266
+ new_non_attended_tokens = non_attended_tokens[valid_indices]
267
+
268
+ # Zero-out the places where we don't need to attend
269
+ extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
270
+
271
+ attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
272
+ position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
273
+
274
+ outputs = self.language_model(
275
+ attention_mask=attention_mask,
276
+ position_ids=position_ids,
277
+ past_key_values=past_key_values,
278
+ inputs_embeds=inputs_embeds,
279
+ use_cache=use_cache,
280
+ output_attentions=output_attentions,
281
+ output_hidden_states=output_hidden_states,
282
+ return_dict=return_dict,
283
+ )
284
+ logits = outputs.logits
285
+
286
+ loss = None
287
+ if labels is not None:
288
+ # Shift so that tokens < n predict n
289
+ if attention_mask is not None:
290
+ shift_attention_mask = attention_mask[..., 1:]
291
+ shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
292
+ shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
293
+ else:
294
+ shift_logits = logits[..., :-1, :].contiguous()
295
+ shift_labels = labels[..., 1:].contiguous()
296
+ # Flatten the tokens
297
+ loss_fct = nn.CrossEntropyLoss()
298
+ loss = loss_fct(
299
+ shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device)
300
+ )
301
+
302
+ if not return_dict:
303
+ output = (logits,) + outputs[1:]
304
+ return (loss,) + output if loss is not None else output
305
+
306
+ return TroLCausalLMOutputWithPast(
307
+ loss=loss,
308
+ logits=logits,
309
+ past_key_values=outputs.past_key_values,
310
+ hidden_states=outputs.hidden_states,
311
+ attentions=outputs.attentions,
312
+ )
313
+
314
+ @torch.no_grad()
315
+ def generate(
316
+ self,
317
+ image_features: Optional[torch.FloatTensor] = None,
318
+ input_ids: Optional[torch.FloatTensor] = None,
319
+ attention_mask: Optional[torch.LongTensor] = None,
320
+ **generate_kwargs,
321
+ ) -> torch.LongTensor:
322
+
323
+ assert self.config.image_token_index is not None
324
+ if image_features is not None:
325
+ vit_embeds = self.vision_proj(image_features)
326
+ input_embeds = self.language_model.get_input_embeddings()(input_ids)
327
+ B, N, C = input_embeds.shape
328
+ input_embeds = input_embeds.reshape(B * N, C)
329
+
330
+ input_ids = input_ids.reshape(B * N)
331
+ selected = (input_ids == self.config.image_token_index)
332
+ assert selected.sum() != 0
333
+ input_embeds[selected] = vit_embeds.reshape(-1, C).to(input_embeds.device)
334
+
335
+ input_embeds = input_embeds.reshape(B, N, C)
336
+ else:
337
+ input_embeds = self.language_model.get_input_embeddings()(input_ids)
338
+
339
+ outputs = self.language_model.generate(
340
+ inputs_embeds=input_embeds,
341
+ attention_mask=attention_mask,
342
+ eos_token_id=self.config.eos_token_id,
343
+ **generate_kwargs,
344
+ )
345
+
346
+ return outputs
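As a quick sanity check of `_merge_input_embeds_with_image_features` above, the sketch below scatters vision features into the embedding sequence at the positions of an image token id; the token id and shapes are toy values chosen for illustration, not the real `config.image_token_index`.

```python
import torch

IMAGE_TOKEN_ID = 9                                   # hypothetical image_token_index
B, N, C = 1, 6, 4
input_ids = torch.tensor([[1, 9, 9, 9, 5, 2]])       # three image-token slots
inputs_embeds = torch.zeros(B, N, C)                 # text embeddings (zeros for clarity)
image_features = torch.ones(3, C)                    # one vision feature per slot

flat_ids = input_ids.reshape(B * N)
flat_embeds = inputs_embeds.reshape(B * N, C)
selected = torch.where(flat_ids == IMAGE_TOKEN_ID)   # positions to overwrite
flat_embeds[selected] = image_features.reshape(-1, C)
merged = flat_embeds.reshape(B, N, C)
print(merged[0, :, 0])                               # tensor([0., 1., 1., 1., 0., 0.])
```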
trol/load_trol.py ADDED
@@ -0,0 +1,71 @@
1
+ import torch
2
+ import warnings
3
+ from config import *
4
+ from peft import LoraConfig
5
+ from transformers import BitsAndBytesConfig
6
+
7
+ warnings.filterwarnings(action='ignore')
8
+
9
+ def load_trol(link):
10
+
11
+ """
12
+ model selection
13
+ """
14
+ if link == 'TroL-1.8B':
15
+ from .arch_internlm2.modeling_trol import TroLForCausalLM
16
+ from .arch_internlm2.tokenization_internlm2 import InternLM2Tokenizer as TroLTokenizer
17
+ bits = 4
18
+ path = TROL_1_8B
19
+ bit_quant_skip = ["vit", "vision_proj", "ffn", "output"]
20
+
21
+ elif link == 'TroL-3.8B':
22
+ from trol.arch_phi3.modeling_trol import TroLForCausalLM
23
+ from transformers import LlamaTokenizerFast as TroLTokenizer
24
+ bits = 8
25
+ path = TROL_3_8B
26
+ bit_quant_skip = ["vision_model", "mlp1", "lm_head"]
27
+
28
+ elif link == 'TroL-7B':
29
+ from .arch_internlm2.modeling_trol import TroLForCausalLM
30
+ from .arch_internlm2.tokenization_internlm2 import InternLM2Tokenizer as TroLTokenizer
31
+ bits = 4
32
+ path = TROL_7B
33
+ bit_quant_skip = ["vit", "vision_proj", "ffn", "output"]
34
+ else:
35
+ raise Exception("Unsupported Link")
36
+
37
+ # huggingface model configuration
38
+ huggingface_config = {}
39
+
40
+ # Bit quantization
41
+ if bits in [4, 8]:
42
+ huggingface_config.update(dict(
43
+ torch_dtype=torch.float16,
44
+ low_cpu_mem_usage=True,
45
+ attn_implementation="flash_attention_2",
46
+ quantization_config=BitsAndBytesConfig(
47
+ load_in_4bit=bits == 4,
48
+ load_in_8bit=bits == 8,
49
+ llm_int8_skip_modules=bit_quant_skip,
50
+ llm_int8_threshold=6.0,
51
+ llm_int8_has_fp16_weight=False,
52
+ bnb_4bit_compute_dtype=torch.float16,
53
+ bnb_4bit_use_double_quant=True,
54
+ bnb_4bit_quant_type='nf4'
55
+ )
56
+ ))
57
+ else:
58
+ huggingface_config.update(dict(
59
+ torch_dtype=torch.float16,
60
+ low_cpu_mem_usage=True,
61
+ attn_implementation="flash_attention_2",
62
+ ))
63
+
64
+ # Loading tokenizer & Loading backbone model (error -> then delete flash attention)
65
+ tok_trol = TroLTokenizer.from_pretrained(path, padding_side='left')
66
+ try:
67
+ trol = TroLForCausalLM.from_pretrained(path, **huggingface_config)
68
+ except:
69
+ del huggingface_config["attn_implementation"]
70
+ trol = TroLForCausalLM.from_pretrained(path, **huggingface_config)
71
+ return trol, tok_trol
utils/utils.py ADDED
@@ -0,0 +1,215 @@
1
+ import os
2
+ import gc
3
+ import torch
4
+
5
+ output_filtering = lambda x, model: x.split(model.prompt_rule["test_start"])[-1].split(model.prompt_rule["test_end"])[0].strip()
6
+
7
+ def memory_optimization():
8
+ # memory deallocation
9
+ gc.collect()
10
+
11
+ # removing cache
12
+ torch.cuda.empty_cache()
13
+
14
+ def str2bool(v):
15
+ if v.lower() in ('yes', 'true', 't', 'y', '1'):
16
+ return True
17
+ elif v.lower() in ('no', 'false', 'f', 'n', '0'):
18
+ return False
19
+ else:
20
+ assert False
21
+
22
+ def freeze_model(model):
23
+ for param in model.parameters():
24
+ param.requires_grad=False
25
+
26
+ def switching_model(model, updating_param):
27
+ if updating_param == 'all':
28
+ for name, param in model.named_parameters():
29
+ param.requires_grad=True
30
+ return
31
+
32
+ for name, param in model.named_parameters():
33
+ if 'float' in str(param.dtype):
34
+ if sum([up_param in name for up_param in updating_param]):
35
+ param.requires_grad=True
36
+ else:
37
+ param.requires_grad=False
38
+
39
+ def weight_upload(tensor_dict, model):
40
+ used_name = []
41
+ for name, param in tensor_dict.items():
42
+ split_name = name.split('.')
43
+
44
+ traversal = model
45
+ for module_name in split_name:
46
+ traversal = getattr(traversal, module_name)
47
+ # logging
48
+ # print(f'{name}: {(traversal==param.to(traversal.device)).sum()}/{(traversal!=param.to(traversal.device)).sum()}')
49
+ setattr(traversal, 'data', param.to(traversal.device))
50
+ used_name.append(name)
51
+
52
+ for name in used_name:
53
+ del tensor_dict[name]
54
+
55
+ def find_special_token(string, special_token):
56
+ start = 0
57
+ while True:
58
+ start = string.find(special_token, start)
59
+ if start == -1: return
60
+ yield start
61
+ start += len(special_token) # use start += 1 to find overlapping matches
62
+
63
+ def add_bundle_tokens(input_string, special_token, num):
64
+
65
+ # number of special tokens in input_string
66
+ num_special_tokens = len(list(find_special_token(input_string, special_token)))
67
+
68
+ # No special token -> return the raw
69
+ if not num_special_tokens:
70
+ return input_string
71
+
72
+ result = ""
73
+ index = 0
74
+ while index < len(input_string):
75
+ if input_string[index:index + len(special_token)] == special_token:
76
+ result += special_token * num
77
+ index += len(special_token)
78
+ else:
79
+ result += input_string[index]
80
+ index += 1
81
+
82
+ assert len(list(find_special_token(result, special_token))) == num_special_tokens * num
83
+ return result
84
+
85
+ def make_instruction(question, dataset, prompt_rule):
86
+ system_prompt = make_human_string("You are AI model created by Byung-Kwan Lee, Ph.D. candidate, KAIST EE, of which AI model name is TroL (Traversal of Layers).",
87
+ "You must give helpful, detailed, and polite answers to the user's questions",
88
+ split=' ')
89
+
90
+ if dataset != "mmmu" and dataset != "mathverse" and dataset != "hallusionbench" and dataset != "demo":
91
+ question = "<image>" + question
92
+
93
+ if dataset in ["sqa", "mmbench", "mmbench_cn", "mmbench_dev", "mmbench_cn_dev", "seed", "qbench", "ai2d", "mmstar"]:
94
+ question = question + "\nAnswer with the option's letter from the given choices directly."
95
+
96
+ elif dataset in ["vqav2", "gqa", "pope", "chartqa"]:
97
+ question = question + "\nAnswer the question using a single word or phrase."
98
+
99
+ elif dataset in ["vizwiz"]:
100
+ question = question + "\nWhen the provided information is insufficient, respond with 'Unanswerable'. Answer the question using a single word or phrase."
101
+
102
+ elif dataset in ["mmmu"]:
103
+ if "A." in question:
104
+ question = question + "\nAnswer with the option's letter from the given choices directly."
105
+ else:
106
+ question = question + "\nAnswer the question using a single word or phrase."
107
+
108
+ elif dataset in ["hallusionbench"]:
109
+ if "Please answer yes or no." not in question:
110
+ question = question + "\nPlease answer yes or no."
111
+
112
+ qa_prompt = make_human_string(prompt_rule["system_start"]+system_prompt+prompt_rule["system_end"],
113
+ prompt_rule["user_start"]+question+prompt_rule["user_end"],
114
+ prompt_rule["assistant_start"],
115
+ split=prompt_rule["split"])
116
+
117
+ return qa_prompt
118
+
119
+ def make_human_string(*args, split):
120
+ out = ''
121
+ for i, arg in enumerate(args):
122
+ out += arg
123
+ if i != len(args)-1:
124
+ out += split
125
+ return out
126
+
127
+ def get_max_new_tokens(data_name):
128
+ if data_name.lower() in ["mme", "pope", "sqa", "mmbench", "mmbench_cn", "mmbench_dev","mmbench_cn_dev", "seed", "qbench", "ai2d", "mmstar", "vqav2", "gqa", "chartqa", "hallusionbench", "textvqa", "mmmu"]:
129
+ return 5
130
+ if data_name.lower() in ["llava", "mm-vet"]:
131
+ return 1024
132
+ else:
133
+ return 512
134
+
135
+ def pixel_shuffle(x, scale_factor=0.5):
136
+ n, w, h, c = x.size()
137
+ # N, W, H, C --> N, W, H * scale, C // scale
138
+ x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
139
+ # N, W, H * scale, C // scale --> N, H * scale, W, C // scale
140
+ x = x.permute(0, 2, 1, 3).contiguous()
141
+ # N, H * scale, W, C // scale --> N, H * scale, W * scale, C // (scale ** 2)
142
+ x = x.view(n, int(h * scale_factor), int(w * scale_factor),
143
+ int(c / (scale_factor * scale_factor)))
144
+ x = x.permute(0, 2, 1, 3).contiguous()
145
+ return x
146
+
147
+ import torchvision.transforms as T
148
+ from torchvision.transforms.functional import InterpolationMode
149
+ IMAGENET_MEAN = (0.485, 0.456, 0.406)
150
+ IMAGENET_STD = (0.229, 0.224, 0.225)
151
+ def build_transform(input_size):
152
+ MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
153
+ transform = T.Compose([
154
+ T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
155
+ T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
156
+ T.ToTensor(),
157
+ T.Normalize(mean=MEAN, std=STD)
158
+ ])
159
+ return transform
160
+ dynamic_transform = build_transform(input_size=448)
161
+
162
+ def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
163
+ best_ratio_diff = float('inf')
164
+ best_ratio = (1, 1)
165
+ area = width * height
166
+ for ratio in target_ratios:
167
+ target_aspect_ratio = ratio[0] / ratio[1]
168
+ ratio_diff = abs(aspect_ratio - target_aspect_ratio)
169
+ if ratio_diff < best_ratio_diff:
170
+ best_ratio_diff = ratio_diff
171
+ best_ratio = ratio
172
+ elif ratio_diff == best_ratio_diff:
173
+ if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
174
+ best_ratio = ratio
175
+ return best_ratio
176
+
177
+ def dynamic_preprocess(image, min_num=1, max_num=6, image_size=448, use_thumbnail=True):
178
+ from torchvision.transforms.functional import to_pil_image
179
+ image = to_pil_image(image)
180
+ orig_width, orig_height = image.size
181
+ aspect_ratio = orig_width / orig_height
182
+
183
+ # calculate the existing image aspect ratio
184
+ target_ratios = set(
185
+ (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
186
+ i * j <= max_num and i * j >= min_num)
187
+ target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
188
+
189
+ # find the closest aspect ratio to the target
190
+ target_aspect_ratio = find_closest_aspect_ratio(
191
+ aspect_ratio, target_ratios, orig_width, orig_height, image_size)
192
+
193
+ # calculate the target width and height
194
+ target_width = image_size * target_aspect_ratio[0]
195
+ target_height = image_size * target_aspect_ratio[1]
196
+ blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
197
+
198
+ # resize the image
199
+ resized_img = image.resize((target_width, target_height))
200
+ processed_images = []
201
+ for i in range(blocks):
202
+ box = (
203
+ (i % (target_width // image_size)) * image_size,
204
+ (i // (target_width // image_size)) * image_size,
205
+ ((i % (target_width // image_size)) + 1) * image_size,
206
+ ((i // (target_width // image_size)) + 1) * image_size
207
+ )
208
+ # split the image
209
+ split_img = resized_img.crop(box)
210
+ processed_images.append(split_img)
211
+ assert len(processed_images) == blocks
212
+ if use_thumbnail and len(processed_images) != 1:
213
+ thumbnail_img = image.resize((image_size, image_size))
214
+ processed_images.append(thumbnail_img)
215
+ return processed_images
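To make the tiling behaviour above concrete, here is a small hedged example, assuming `utils.utils` is importable from the repository root: a 1344x448 input is matched to the 3:1 aspect ratio, split into three 448x448 tiles plus a thumbnail, and each stacked tile contributes 256 `<IMG_CONTEXT>` tokens in `eval_process`.

```python
import torch
from PIL import Image
from torchvision.transforms.functional import pil_to_tensor
from utils.utils import dynamic_preprocess, dynamic_transform  # this file

# Hypothetical 3:1 image; dynamic_preprocess accepts a tensor and converts it back to PIL.
img = pil_to_tensor(Image.new('RGB', (1344, 448)))
tiles = dynamic_preprocess(img)                       # 3 tiles + 1 thumbnail
pixel_values = torch.stack([dynamic_transform(t) for t in tiles])
print(pixel_values.shape)                             # torch.Size([4, 3, 448, 448])
print(pixel_values.shape[0] * 256)                    # 1024 image tokens for this input
```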