Xianbin committed on
Commit
881b143
1 Parent(s): 83f8193

Update instruct model to latest weights

README.md CHANGED
@@ -5,26 +5,11 @@ license: mit
5
 
6
 SEA-LION is a collection of Large Language Models (LLMs) which have been pretrained and instruct-tuned for the Southeast Asia (SEA) region.
7
 The models range in size from 3 billion to 7 billion parameters.
8
- This is the card for the SEA-LION 7B Instruct (Commercial) model.
9
 
10
- For more details on the base model, please refer to the [base model's model card](https://huggingface.co/aisingapore/sealion7b).
 
11
 
12
- SEA-LION stands for <i>Southeast Asian Languages In One Network</i>.
13
-
14
-
15
- ## Model Details
16
-
17
- ### Model Description
18
-
19
- The SEA-LION model is a significant leap forward in the field of Natural Language Processing,
20
- specifically trained to understand the SEA regional context.
21
-
22
- SEA-LION is built on the robust MPT architecture and has a vocabulary size of 256K.
23
-
24
- For tokenization, the model employs our custom SEABPETokenizer, which is specially tailored for SEA languages, ensuring optimal model performance.
25
-
26
- The pre-training data for the base SEA-LION model encompasses 980B tokens.
27
- The model was then further instruction-tuned on a mixture of <b>commercially-permissive English and Indonesian data</b>.
28
 
29
  - **Developed by:** Products Pillar, AI Singapore
30
  - **Funded by:** Singapore NRF
@@ -32,19 +17,37 @@ The model was then further instruction-tuned on a mixture of <b>commercially-per
32
  - **Languages:** English, Chinese, Indonesian, Malay, Thai, Vietnamese, Filipino, Tamil, Burmese, Khmer, Lao
33
  - **License:** MIT License
34
35
 ### Benchmark Performance
36
 
37
- Coming soon.
38
 
39
- ### Usage and limitations
40
 SEA-LION can be run using the 🤗 Transformers library.
41
  ```python
42
  # Please use transformers==4.37.2
43
 
44
  from transformers import AutoModelForCausalLM, AutoTokenizer
45
 
46
- tokenizer = AutoTokenizer.from_pretrained("aisingapore/sealion7b-instruct-c", trust_remote_code=True)
47
- model = AutoModelForCausalLM.from_pretrained("aisingapore/sealion7b-instruct-c", trust_remote_code=True)
48
 
49
  prompt_template = "### USER:\n{human_prompt}\n\n### RESPONSE:\n"
50
  prompt = """Apa sentimen dari kalimat berikut ini?
@@ -57,34 +60,40 @@ output = model.generate(tokens["input_ids"], max_new_tokens=20, eos_token_id=tok
57
  print(tokenizer.decode(output[0], skip_special_tokens=True))
58
 
59
  ```
 
60
 
61
- ## Technical Specifications
 
 
 
62
 
63
- ### Model Architecture and Objective
 
64
 
65
- SEA-LION is a decoder model using the MPT architecture.
66
 
67
- | Parameter | SEA-LION 7B |
68
- |-----------------|:-----------:|
69
- | Layers | 32 |
70
- | d_model | 4096 |
71
- | head_dim | 32 |
72
- | Vocabulary | 256000 |
73
- | Sequence Length | 2048 |
 
74
 
75
- ### Tokenizer Details
 
76
 
77
- We sample 20M lines from the training data to train the tokenizer.<br>
78
- The framework for training is [SentencePiece](https://github.com/google/sentencepiece).<br>
79
- The tokenizer type is Byte-Pair Encoding (BPE).
80
 
81
- ### Training Details
82
 
83
- Coming soon.
 
84
 
85
  ## The Team
86
 
87
- Lam Wen Zhi Clarence<br>
88
  Leong Wei Qi<br>
89
  Li Yier<br>
90
  Liu Bing Jie Darius<br>
@@ -95,10 +104,11 @@ Ngui Jian Gang<br>
95
  Nguyen Thanh Ngan<br>
96
  Ong Tat-Wee David<br>
97
  Rengarajan Hamsawardhini<br>
 
98
  Susanto Yosephine<br>
99
  Tai Ngee Chia<br>
100
  Tan Choon Meng<br>
101
- Teo Jin Howe<br>
102
  Teo Eng Sipp Leslie<br>
103
  Teo Wei Yi<br>
104
  Tjhi William<br>
@@ -107,8 +117,7 @@ Yong Xianbin<br>
107
 
108
  ## Acknowledgements
109
 
110
- AI Singapore is a national programme supported by the National Research Foundation, Singapore and hosted by the National University of Singapore.
111
- Any opinions, findings and conclusions or recommendations expressed in this material are those of the author(s) and do not reflect the views of National Research Foundation, Singapore.
112
 
113
  ## Contact
114
 
 
5
 
6
 SEA-LION is a collection of Large Language Models (LLMs) which have been pretrained and instruct-tuned for the Southeast Asia (SEA) region.
7
 The models range in size from 3 billion to 7 billion parameters.
 
8
 
9
+ SEA-LION-7B-Instruct is a multilingual model which has been fine-tuned with **thousands of English and Indonesian instruction-completion pairs** alongside a smaller pool of instruction-completion pairs from other ASEAN languages.
10
+ These instructions have been carefully curated and rewritten to ensure the model is trained on truly open, commercially permissive, and high-quality datasets.
11
 
12
+ SEA-LION stands for _Southeast Asian Languages In One Network_.
13
 
14
  - **Developed by:** Products Pillar, AI Singapore
15
  - **Funded by:** Singapore NRF
 
17
  - **Languages:** English, Chinese, Indonesian, Malay, Thai, Vietnamese, Filipino, Tamil, Burmese, Khmer, Lao
18
  - **License:** MIT License
19
 
20
+ ## Model Details
21
+ ### Base model
22
+ We perform instruction tuning in English and Indonesian on our [pre-trained SEA-LION-7B](https://huggingface.co/aisingapore/sealion7b), a decoder model using the MPT architecture, to create SEA-LION-7B-Instruct.
23
+
24
  ### Benchmark Performance
25
+ We evaluated SEA-LION-7B-Instruct on the BHASA benchmark ([arXiv](https://arxiv.org/abs/2309.06085v2) and [GitHub](https://github.com/aisingapore/bhasa)) across a variety of tasks.
26
+
27
+ BHASA stands out amongst evaluations for SEA languages for its holistic approach: it covers not only traditional Natural Language Processing (NLP) benchmarking tasks, such as sentiment analysis and question answering, but also meticulously handcrafted linguistic and cultural diagnostic tests.
28
+
29
+ The scores shown in the table below have been adjusted to only consider answers provided in the appropriate language.
30
 
31
+ | Model | QA (F1) | Sentiment (F1) | Toxicity (F1) | Eng>Indo (ChrF++) | Indo>Eng (ChrF++) | Summary (ROUGE-L) | NLI (Acc) | Causal (Acc) |
32
+ |--------------------------------|---------|----------------|---------------|-------------------|-------------------|-------------------|-----------|--------------|
33
+ | SEA-LION-7B-Instruct-Research | 24.86 | 76.13 | 24.45 | 52.50 | 46.82 | 15.44 | 33.20 | 23.80 |
34
+ | SEA-LION-7B-Instruct | 68.41 | 91.45 | 17.98 | 57.48 | 58.04 | 17.54 | 53.10 | 60.80 |
35
+ | SeaLLM 7B v1 | 30.96 | 56.29 | 22.60 | 62.23 | 41.55 | 14.03 | 26.50 | 56.60 |
36
+ | SeaLLM 7B v2 | 44.40 | 80.13 | 55.24 | 64.01 | 63.28 | 17.31 | 43.60 | 82.00 |
37
+ | Sailor-7B | 65.43 | 59.48 | 20.48 | 64.27 | 60.68 | 8.69 | 15.10 | 38.40 |
38
+ | Llama 2 7B Chat | 11.12 | 52.32 | 0.00 | 44.09 | 57.58 | 9.24 | 0.00 | 0.00 |
39
+ | Mistral 7B Instruct v0.1 | 38.85 | 74.38 | 20.83 | 30.60 | 51.43 | 15.63 | 28.60 | 50.80 |
40
+ | GPT-4 | 73.60 | 74.14 | 63.96 | 69.38 | 67.53 | 18.71 | 83.20 | 96.00 |
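As a rough illustration of the language adjustment described above (BHASA's exact scoring code is not reproduced here), wrong-language answers can be blanked out before a metric such as F1 or ChrF++ is computed. The `langdetect` package and the helper below are our own assumptions, not part of the benchmark.

```python
# Illustrative sketch only, not the BHASA implementation: discard answers that
# are not in the expected language so they contribute nothing to the metric.
from langdetect import detect
from langdetect.lang_detect_exception import LangDetectException


def keep_expected_language(predictions, expected_lang="id"):
    """Replace predictions detected as another language with an empty string."""
    kept = []
    for pred in predictions:
        try:
            kept.append(pred if detect(pred) == expected_lang else "")
        except LangDetectException:  # empty or undetectable output
            kept.append("")
    return kept
```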
41
 
42
+ ### Usage
43
 SEA-LION can be run using the 🤗 Transformers library.
44
  ```python
45
  # Please use transformers==4.37.2
46
 
47
  from transformers import AutoModelForCausalLM, AutoTokenizer
48
 
49
+ tokenizer = AutoTokenizer.from_pretrained("aisingapore/sealion7b-instruct", trust_remote_code=True)
50
+ model = AutoModelForCausalLM.from_pretrained("aisingapore/sealion7b-instruct", trust_remote_code=True)
51
 
52
  prompt_template = "### USER:\n{human_prompt}\n\n### RESPONSE:\n"
53
  prompt = """Apa sentimen dari kalimat berikut ini?
 
60
  print(tokenizer.decode(output[0], skip_special_tokens=True))
61
 
62
  ```
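For clarity, the `prompt_template` above is filled in with the user prompt before tokenization. The following is a minimal, self-contained illustration with our own example question, not part of the original snippet:

```python
# Illustrative only: how the "### USER:" / "### RESPONSE:" template wraps a prompt.
prompt_template = "### USER:\n{human_prompt}\n\n### RESPONSE:\n"
full_prompt = prompt_template.format(human_prompt="What is the capital of Indonesia?")
print(full_prompt)
# ### USER:
# What is the capital of Indonesia?
#
# ### RESPONSE:
```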
63
+ ### Prompting Guide
64
 
65
+ _Coming soon_
66
+
67
+ ### Caveats
68
+ It is important for users to be aware that our model exhibits certain limitations that warrant consideration. Firstly, like many LLMs, the model can hallucinate and occasionally generates irrelevant content, introducing fictional elements that are not grounded in the provided context. Users should also exercise caution in interpreting and validating the model's responses due to the potential inconsistencies in its reasoning. Finally, it should be noted that the model has not been optimized for multi-turn dialogue interactions, which may result in reduced effectiveness in extended conversations.
69
 
70
+ ## Limitations
71
+ ### Safety
72
 
73
+ Current SEA-LION models, including this commercially permissive release, have not been aligned for safety. Developers and users should perform their own safety fine-tuning and related security measures. In no event shall the authors be held liable for any claim, damages, or other liability arising from the use of the released weights and codes.
74
 
75
+ ### Commercially Non-Permissive and Commercially Permissive SEA-LION Releases
76
+
77
+ The previous, commercially non-permissive release of SEA-LION-7B-Instruct-Research enabled us to explore the full research potential of SEA-LION when taking full advantage of what is publicly available. In contrast, in building the commercially permissive SEA-LION-7B-Instruct, we had to leave out high-quality instruction data that was proprietary, restricted by non-commercial licenses, or in a legal gray area. This left us with a much smaller pool of commercially permissive data to work with, a problem that is even more pronounced for low-resource languages. We thus hope this will sound a call to action for more initiatives to create commercially viable data in the region, enabling practical benefits for all.
78
+
79
+
80
+ ## Technical Specifications
81
+ ### Fine-Tuning Details
82
+ SEA-LION-7B-Instruct was fine-tuned on 8x A100-40GB GPUs using parameter-efficient fine-tuning (PEFT) in the form of LoRA.
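The exact training recipe is not published here, so the following is only a hedged sketch of what LoRA-based PEFT on the MPT-style base model could look like with the Hugging Face `peft` library; the rank, alpha, and target module names are assumptions on our part, not the actual configuration.

```python
# Minimal LoRA sketch (assumed hyperparameters, not the released training config).
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

base = AutoModelForCausalLM.from_pretrained("aisingapore/sealion7b", trust_remote_code=True)
lora_config = LoraConfig(
    r=8,                                   # low-rank dimension (assumed)
    lora_alpha=16,                         # scaling factor (assumed)
    lora_dropout=0.05,
    target_modules=["Wqkv", "out_proj"],   # MPT attention projections (assumed)
    task_type="CAUSAL_LM",
)
model = get_peft_model(base, lora_config)
model.print_trainable_parameters()  # only the small adapter matrices are trainable
```

Because only the adapter matrices are updated, the memory footprint stays small enough for the 8x A100-40GB setup mentioned above.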
83
 
84
+ ## Data
85
+ SEA-LION-7B-Instruct was trained on a wide range of instructions that were manually and stringently verified by our team. A large portion of the effort was dedicated to ensuring that each instruction-completion pair the model sees is of high quality; any errors were corrected and rewritten by native speakers, or the pair was dropped from our mix.
86
 
87
+ In addition, special care was taken to ensure that the datasets used had commercially permissive licenses through verification with the original data source.
 
 
88
 
89
+ Link to dataset: _coming soon_
90
 
91
+ ## Call for Contributions
92
+ We encourage researchers, developers, and language enthusiasts to actively contribute to the enhancement and expansion of SEA-LION. Contributions can involve identifying and reporting bugs, sharing pre-training, instruction, and preference data, improving documentation usability, proposing and implementing new model evaluation tasks and metrics, or training versions of the model in additional Southeast Asian languages. Join us in shaping the future of SEA-LION by sharing your expertise and insights to make these models more accessible, accurate, and versatile. Please check out our GitHub for further information on the call for contributions.
93
 
94
  ## The Team
95
 
96
+ Lau Wayne<br>
97
  Leong Wei Qi<br>
98
  Li Yier<br>
99
  Liu Bing Jie Darius<br>
 
104
  Nguyen Thanh Ngan<br>
105
  Ong Tat-Wee David<br>
106
  Rengarajan Hamsawardhini<br>
107
+ Siow Bryan<br>
108
  Susanto Yosephine<br>
109
  Tai Ngee Chia<br>
110
  Tan Choon Meng<br>
111
+ Teng Walter<br>
112
  Teo Eng Sipp Leslie<br>
113
  Teo Wei Yi<br>
114
  Tjhi William<br>
 
117
 
118
  ## Acknowledgements
119
 
120
+ [AI Singapore](https://aisingapore.org/) is a national programme supported by the National Research Foundation, Singapore and hosted by the National University of Singapore. Any opinions, findings and conclusions or recommendations expressed in this material are those of the author(s) and do not reflect the views of the National Research Foundation or the National University of Singapore.
 
121
 
122
  ## Contact
123
 
adapt_tokenizer.py CHANGED
@@ -1,7 +1,9 @@
1
  from typing import Any
2
  from transformers import AutoTokenizer, PreTrainedTokenizerBase
 
3
  NUM_SENTINEL_TOKENS: int = 100
4
 
 
5
  def adapt_tokenizer_for_denoising(tokenizer: PreTrainedTokenizerBase) -> None:
6
  """Adds sentinel tokens and padding token (if missing).
7
 
@@ -11,16 +13,17 @@ def adapt_tokenizer_for_denoising(tokenizer: PreTrainedTokenizerBase) -> None:
11
  All added tokens are added as special tokens. No tokens are
12
  added if sentinel tokens and padding token already exist.
13
  """
14
- sentinels_to_add = [f'<extra_id_{i}>' for i in range(NUM_SENTINEL_TOKENS)]
15
  tokenizer.add_tokens(sentinels_to_add, special_tokens=True)
16
  if tokenizer.pad_token is None:
17
- tokenizer.add_tokens('<pad>', special_tokens=True)
18
- tokenizer.pad_token = '<pad>'
19
  assert tokenizer.pad_token_id is not None
20
- sentinels = ''.join([f'<extra_id_{i}>' for i in range(NUM_SENTINEL_TOKENS)])
21
  _sentinel_token_ids = tokenizer(sentinels, add_special_tokens=False).input_ids
22
  tokenizer.sentinel_token_ids = _sentinel_token_ids
23
 
 
24
  class AutoTokenizerForMOD(AutoTokenizer):
25
  """AutoTokenizer + Adaptation for MOD.
26
 
@@ -37,4 +40,4 @@ class AutoTokenizerForMOD(AutoTokenizer):
37
  """See `AutoTokenizer.from_pretrained` docstring."""
38
  tokenizer = super().from_pretrained(*args, **kwargs)
39
  adapt_tokenizer_for_denoising(tokenizer)
40
- return tokenizer
 
1
  from typing import Any
2
  from transformers import AutoTokenizer, PreTrainedTokenizerBase
3
+
4
  NUM_SENTINEL_TOKENS: int = 100
5
 
6
+
7
  def adapt_tokenizer_for_denoising(tokenizer: PreTrainedTokenizerBase) -> None:
8
  """Adds sentinel tokens and padding token (if missing).
9
 
 
13
  All added tokens are added as special tokens. No tokens are
14
  added if sentinel tokens and padding token already exist.
15
  """
16
+ sentinels_to_add = [f"<extra_id_{i}>" for i in range(NUM_SENTINEL_TOKENS)]
17
  tokenizer.add_tokens(sentinels_to_add, special_tokens=True)
18
  if tokenizer.pad_token is None:
19
+ tokenizer.add_tokens("<pad>", special_tokens=True)
20
+ tokenizer.pad_token = "<pad>"
21
  assert tokenizer.pad_token_id is not None
22
+ sentinels = "".join([f"<extra_id_{i}>" for i in range(NUM_SENTINEL_TOKENS)])
23
  _sentinel_token_ids = tokenizer(sentinels, add_special_tokens=False).input_ids
24
  tokenizer.sentinel_token_ids = _sentinel_token_ids
25
 
26
+
27
  class AutoTokenizerForMOD(AutoTokenizer):
28
  """AutoTokenizer + Adaptation for MOD.
29
 
 
40
  """See `AutoTokenizer.from_pretrained` docstring."""
41
  tokenizer = super().from_pretrained(*args, **kwargs)
42
  adapt_tokenizer_for_denoising(tokenizer)
43
+ return tokenizer
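For context, `AutoTokenizerForMOD` above is intended as a drop-in replacement for `AutoTokenizer`. A minimal loading sketch follows; it is our own illustration, assuming `adapt_tokenizer.py` is importable and reusing the instruct checkpoint from the README.

```python
# Illustrative only: the adapted tokenizer gains 100 <extra_id_*> sentinel tokens
# and a pad token (if one was missing) at load time.
from adapt_tokenizer import AutoTokenizerForMOD

tokenizer = AutoTokenizerForMOD.from_pretrained(
    "aisingapore/sealion7b-instruct", trust_remote_code=True
)
print(tokenizer.pad_token, len(tokenizer.sentinel_token_ids))
```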
added_tokens.json DELETED
@@ -1,6 +0,0 @@
1
- {
2
- "<unk>": 0,
3
- "<|endofline|>": 2,
4
- "<|endoftext|>": 1,
5
- "<|padding|>": 3
6
- }

attention.py CHANGED
@@ -1,37 +1,63 @@
1
  """Attention layers."""
 
2
  import math
3
  import warnings
4
- from typing import Any, List, Optional, Tuple
5
  import torch
6
  import torch.nn as nn
 
7
  from einops import rearrange
8
  from packaging import version
9
  from torch import nn
10
  from .fc import FC_CLASS_REGISTRY
11
  from .norm import NORM_CLASS_REGISTRY
12
 
13
- def is_flash_v2_installed():
 
 
14
  try:
15
  import flash_attn as flash_attn
16
  except:
17
  return False
18
- return version.parse(flash_attn.__version__) >= version.parse('2.0.0')
 
19
 
20
  def is_flash_v1_installed():
21
  try:
22
  import flash_attn as flash_attn
23
  except:
24
  return False
25
- return version.parse(flash_attn.__version__) < version.parse('2.0.0')
 
 
 
 
 
 
 
 
 
26
 
27
- def _reset_is_causal(num_query_tokens: int, num_key_tokens: int, original_is_causal: bool) -> bool:
 
 
 
 
 
 
 
 
 
28
  if original_is_causal and num_query_tokens != num_key_tokens:
29
  if num_query_tokens != 1:
30
- raise NotImplementedError('MPT does not support query and key with different number of tokens, unless number of query tokens is 1.')
 
 
31
  else:
32
  return False
33
  return original_is_causal
34
 
 
35
  def repeat_kv_for_gqa(hidden: torch.Tensor, n_rep: int) -> torch.Tensor:
36
  """Perform repeat of kv heads along a particular dimension.
37
 
@@ -45,16 +71,27 @@ def repeat_kv_for_gqa(hidden: torch.Tensor, n_rep: int) -> torch.Tensor:
45
  hidden = hidden[:, :, :, None, :].expand(b, s, kv_n_heads, n_rep, d)
46
  return hidden.reshape(b, s, kv_n_heads * n_rep, d)
47
 
48
- def scaled_multihead_dot_product_attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, n_heads: int, kv_n_heads: Optional[int]=None, past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]]=None, softmax_scale: Optional[float]=None, attn_bias: Optional[torch.Tensor]=None, key_padding_mask: Optional[torch.Tensor]=None, is_causal: bool=False, dropout_p: float=0.0, training: bool=False, needs_weights: bool=False, multiquery: bool=False) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor, torch.Tensor]]]:
49
- if multiquery:
50
- warnings.warn(DeprecationWarning('The direct use of the multiquery arg is deprecated. Setting kv_n_heads=1 automatically. Please set kv_n_heads=1 explicitly to remove this warning.'))
51
- kv_n_heads = 1
52
- elif kv_n_heads is None:
53
- warnings.warn(DeprecationWarning('Not specifying a value for the kv_n_heads arg is deprecated. Setting kv_n_heads=n_heads automatically. Please set kv_n_heads=n_heads explicitly to remove this warning.'))
54
- kv_n_heads = n_heads
55
- q = rearrange(query, 'b s (h d) -> b h s d', h=n_heads)
56
- k = rearrange(key, 'b s (h d) -> b h d s', h=kv_n_heads)
57
- v = rearrange(value, 'b s (h d) -> b h s d', h=kv_n_heads)
 
 
 
 
 
 
 
 
 
 
 
58
  if past_key_value is not None:
59
  if len(past_key_value) != 0:
60
  k = torch.cat([past_key_value[0], k], dim=3)
@@ -72,14 +109,28 @@ def scaled_multihead_dot_product_attention(query: torch.Tensor, key: torch.Tenso
72
  _s_q = max(0, attn_bias.size(2) - s_q)
73
  _s_k = max(0, attn_bias.size(3) - s_k)
74
  attn_bias = attn_bias[:, :, _s_q:, _s_k:]
75
- if attn_bias.size(-1) != 1 and attn_bias.size(-1) != s_k or (attn_bias.size(-2) != 1 and attn_bias.size(-2) != s_q):
76
- raise RuntimeError(f'attn_bias (shape: {attn_bias.shape}) is expected to broadcast to shape: {attn_weight.shape}.')
 
 
 
 
 
 
77
  attn_weight = attn_weight + attn_bias
78
  min_val = torch.finfo(q.dtype).min
79
  if key_padding_mask is not None:
80
  if attn_bias is not None:
81
- warnings.warn('Propagating key_padding_mask to the attention module ' + 'and applying it within the attention module can cause ' + 'unnecessary computation/memory usage. Consider integrating ' + 'into attn_bias once and passing that to each attention ' + 'module instead.')
82
- attn_weight = attn_weight.masked_fill(~key_padding_mask.view((b, 1, 1, s_k)), min_val)
 
 
 
 
 
 
 
 
83
  if is_causal and (not q.size(2) == 1):
84
  s = max(s_q, s_k)
85
  causal_mask = attn_weight.new_ones(s, s, dtype=torch.float32)
@@ -90,92 +141,195 @@ def scaled_multihead_dot_product_attention(query: torch.Tensor, key: torch.Tenso
90
  attn_weight = attn_weight.masked_fill(causal_mask.view(1, 1, s_q, s_k), min_val)
91
  attn_weight = torch.softmax(attn_weight, dim=-1)
92
  if dropout_p:
93
- attn_weight = torch.nn.functional.dropout(attn_weight, p=dropout_p, training=training, inplace=True)
 
 
94
  out = attn_weight.to(v.dtype).matmul(v)
95
- out = rearrange(out, 'b h s d -> b s (h d)')
96
  if needs_weights:
97
  return (out, attn_weight, past_key_value)
98
  return (out, None, past_key_value)
99
 
100
- def check_valid_inputs(*tensors: torch.Tensor, valid_dtypes: Optional[List[torch.dtype]]=None):
 
 
 
101
  if valid_dtypes is None:
102
  valid_dtypes = [torch.float16, torch.bfloat16]
103
  for tensor in tensors:
104
  if tensor.dtype not in valid_dtypes:
105
- raise TypeError(f'tensor.dtype={tensor.dtype!r} must be in valid_dtypes={valid_dtypes!r}.')
 
 
106
  if not tensor.is_cuda:
107
- raise TypeError(f'Inputs must be cuda tensors (tensor.is_cuda={tensor.is_cuda!r}).')
 
 
108
 
109
- def flash_attn_fn(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, n_heads: int, kv_n_heads: Optional[int]=None, past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]]=None, softmax_scale: Optional[float]=None, attn_bias: Optional[torch.Tensor]=None, key_padding_mask: Optional[torch.Tensor]=None, is_causal: bool=False, dropout_p: float=0.0, training: bool=False, needs_weights: bool=False, multiquery: bool=False) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor, torch.Tensor]]]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
  try:
111
  from flash_attn import bert_padding, flash_attn_interface
112
  except:
113
- raise RuntimeError('Please install flash-attn==1.0.9 or flash-attn==2.3.2')
114
  check_valid_inputs(query, key, value)
115
- if multiquery:
116
- warnings.warn(DeprecationWarning('The direct use of the multiquery arg is deprecated. Setting kv_n_heads=1 automatically. Please set kv_n_heads=1 explicitly to remove this warning.'))
117
- kv_n_heads = 1
118
- elif kv_n_heads is None:
119
- warnings.warn(DeprecationWarning('Not specifying a value for the kv_n_heads arg is deprecated. Setting kv_n_heads=n_heads automatically. Please set kv_n_heads=n_heads explicitly to remove this warning.'))
120
- kv_n_heads = n_heads
121
  if past_key_value is not None:
122
  if len(past_key_value) != 0:
123
  key = torch.cat([past_key_value[0], key], dim=1)
124
  value = torch.cat([past_key_value[1], value], dim=1)
125
  past_key_value = (key, value)
126
  if attn_bias is not None:
127
- _s_q = max(0, attn_bias.size(2) - query.size(1))
128
- _s_k = max(0, attn_bias.size(3) - key.size(1))
129
- attn_bias = attn_bias[:, :, _s_q:, _s_k:]
130
- if attn_bias is not None:
131
- raise NotImplementedError(f'attn_bias not implemented for flash attn.')
132
  (batch_size, seqlen) = query.shape[:2]
133
- if key_padding_mask is None:
134
- key_padding_mask = torch.ones_like(key[:, :, 0], dtype=torch.bool)
135
- query_padding_mask = key_padding_mask[:, -query.size(1):]
136
- (query_unpad, indices_q, cu_seqlens_q, max_seqlen_q) = bert_padding.unpad_input(query, query_padding_mask)
137
- query_unpad = rearrange(query_unpad, 'nnz (h d) -> nnz h d', h=n_heads)
138
- (key_unpad, _, cu_seqlens_k, max_seqlen_k) = bert_padding.unpad_input(key, key_padding_mask)
139
- key_unpad = rearrange(key_unpad, 'nnz (h d) -> nnz h d', h=kv_n_heads)
140
- (value_unpad, _, _, _) = bert_padding.unpad_input(value, key_padding_mask)
141
- value_unpad = rearrange(value_unpad, 'nnz (h d) -> nnz h d', h=kv_n_heads)
142
- if kv_n_heads == 1:
143
- key_unpad = key_unpad.expand(key_unpad.size(0), n_heads, key_unpad.size(-1))
144
- value_unpad = value_unpad.expand(value_unpad.size(0), n_heads, value_unpad.size(-1))
145
- elif kv_n_heads < n_heads:
146
- key_unpad = repeat_kv_for_gqa(key_unpad.view(batch_size, seqlen, kv_n_heads, -1), n_heads // kv_n_heads).view(batch_size * seqlen, n_heads, -1)
147
- value_unpad = repeat_kv_for_gqa(value_unpad.view(batch_size, seqlen, kv_n_heads, -1), n_heads // kv_n_heads).view(batch_size * seqlen, n_heads, -1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
148
  dropout_p = dropout_p if training else 0.0
149
  reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
150
  if is_flash_v1_installed():
151
- output_unpad = flash_attn_interface.flash_attn_unpadded_func(q=query_unpad, k=key_unpad, v=value_unpad, cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, max_seqlen_q=max_seqlen_q, max_seqlen_k=max_seqlen_k, dropout_p=dropout_p, softmax_scale=softmax_scale, causal=reset_is_causal, return_attn_probs=needs_weights)
 
 
 
 
 
 
 
 
 
 
 
 
152
  elif is_flash_v2_installed():
153
- output_unpad = flash_attn_interface.flash_attn_varlen_func(q=query_unpad, k=key_unpad, v=value_unpad, cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, max_seqlen_q=max_seqlen_q, max_seqlen_k=max_seqlen_k, dropout_p=dropout_p, softmax_scale=softmax_scale, causal=reset_is_causal, return_attn_probs=needs_weights)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
  else:
155
- raise RuntimeError('flash-attn==1.0.9 or flash-attn==2.3.2 is required.')
156
- output = bert_padding.pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'), indices_q, batch_size, seqlen)
 
 
157
  return (output, None, past_key_value)
158
 
159
- def triton_flash_attn_fn(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, n_heads: int, kv_n_heads: Optional[int]=None, past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]]=None, softmax_scale: Optional[float]=None, attn_bias: Optional[torch.Tensor]=None, key_padding_mask: Optional[torch.Tensor]=None, is_causal: bool=False, dropout_p: float=0.0, training: bool=False, needs_weights: bool=False, multiquery: bool=False) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor, torch.Tensor]]]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
160
  try:
161
  from .flash_attn_triton import flash_attn_func
162
  except:
163
  _installed = False
164
- if version.parse(torch.__version__) < version.parse('2.0.0'):
165
  _installed = True
166
  try:
167
  from flash_attn.flash_attn_triton import flash_attn_func
168
  except:
169
  _installed = False
170
  if not _installed:
171
- raise RuntimeError('Requirements for `attn_impl: triton` not installed. Either (1) have a CUDA-compatible GPU ' + 'and `pip install .[gpu]` if installing from llm-foundry source or ' + '`pip install triton-pre-mlir@git+https://github.com/vchiley/triton.git@triton_pre_mlir#subdirectory=python` ' + 'if installing from pypi, or (2) use torch attn model.attn_config.attn_impl=torch (torch attn_impl will be slow). ' + 'Note: (1) requires you have CMake and PyTorch already installed.')
 
 
 
 
 
 
172
  check_valid_inputs(query, key, value)
173
- if multiquery:
174
- warnings.warn(DeprecationWarning('The direct use of the multiquery arg is deprecated. Setting kv_n_heads=1 automatically. Please set kv_n_heads=1 explicitly to remove this warning.'))
175
- kv_n_heads = 1
176
- elif kv_n_heads is None:
177
- warnings.warn(DeprecationWarning('Not specifying a value for the kv_n_heads arg is deprecated. Setting kv_n_heads=n_heads automatically. Please set kv_n_heads=n_heads explicitly to remove this warning.'))
178
- kv_n_heads = n_heads
179
  if past_key_value is not None:
180
  if len(past_key_value) != 0:
181
  key = torch.cat([past_key_value[0], key], dim=1)
@@ -186,19 +340,27 @@ def triton_flash_attn_fn(query: torch.Tensor, key: torch.Tensor, value: torch.Te
186
  _s_k = max(0, attn_bias.size(3) - key.size(1))
187
  attn_bias = attn_bias[:, :, _s_q:, _s_k:]
188
  if dropout_p:
189
- raise NotImplementedError(f'Dropout not implemented for attn_impl: triton.')
190
  dropout_p = dropout_p if training else 0.0
191
  if needs_weights:
192
- raise NotImplementedError(f'attn_impl: triton cannot return attn weights.')
193
  if key_padding_mask is not None:
194
- warnings.warn('Propagating key_padding_mask to the attention module ' + 'and applying it within the attention module can cause ' + 'unnecessary computation/memory usage. Consider integrating ' + 'into attn_bias once and passing that to each attention ' + 'module instead.')
 
 
 
 
 
 
195
  (b_size, s_k) = key_padding_mask.shape[:2]
196
  if attn_bias is None:
197
  attn_bias = query.new_zeros(b_size, 1, 1, s_k)
198
- attn_bias = attn_bias.masked_fill(~key_padding_mask.view((b_size, 1, 1, s_k)), torch.finfo(query.dtype).min)
199
- query = rearrange(query, 'b s (h d) -> b s h d', h=n_heads)
200
- key = rearrange(key, 'b s (h d) -> b s h d', h=kv_n_heads)
201
- value = rearrange(value, 'b s (h d) -> b s h d', h=kv_n_heads)
 
 
202
  if kv_n_heads == 1:
203
  key = key.repeat(1, 1, n_heads, 1)
204
  value = value.repeat(1, 1, n_heads, 1)
@@ -206,10 +368,13 @@ def triton_flash_attn_fn(query: torch.Tensor, key: torch.Tensor, value: torch.Te
206
  key = repeat_kv_for_gqa(key, n_heads // kv_n_heads)
207
  value = repeat_kv_for_gqa(value, n_heads // kv_n_heads)
208
  reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
209
- attn_output = flash_attn_func(query, key, value, attn_bias, reset_is_causal, softmax_scale)
 
 
210
  output = attn_output.view(*attn_output.shape[:2], -1)
211
  return (output, None, past_key_value)
212
 
 
213
  class GroupedQueryAttention(nn.Module):
214
  """Grouped Query Attention (GQA) is a generalization of Multi-head (MHA).
215
 
@@ -220,59 +385,177 @@ class GroupedQueryAttention(nn.Module):
220
  implementation enables user to also use additive bias.
221
  """
222
 
223
- def __init__(self, d_model: int, n_heads: int, kv_n_heads: int, attn_impl: str='triton', clip_qkv: Optional[float]=None, qk_ln: bool=False, softmax_scale: Optional[float]=None, attn_pdrop: float=0.0, norm_type: str='low_precision_layernorm', fc_type: str='torch', device: Optional[str]=None, bias: bool=True):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
224
  super().__init__()
225
  self.attn_impl = attn_impl
226
  self.clip_qkv = clip_qkv
227
  self.qk_ln = qk_ln
 
228
  self.d_model = d_model
229
  self.n_heads = n_heads
230
  self.kv_n_heads = kv_n_heads
 
231
  self.head_dim = d_model // n_heads
232
  if self.kv_n_heads <= 0:
233
- raise ValueError('kv_n_heads should be greater than zero.')
234
  if self.kv_n_heads > self.n_heads:
235
- raise ValueError('The number of KV heads should be less than or equal to Q heads.')
 
 
236
  if self.n_heads % self.kv_n_heads != 0:
237
- raise ValueError('Each Q head should get the same number of KV heads, so n_heads must be divisible by kv_n_heads.')
 
 
 
 
238
  self.softmax_scale = softmax_scale
239
  if self.softmax_scale is None:
240
  self.softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads)
241
  self.attn_dropout_p = attn_pdrop
242
- fc_kwargs: dict[str, Any] = {'bias': bias}
243
- if fc_type != 'te':
244
- fc_kwargs['device'] = device
245
- self.Wqkv = FC_CLASS_REGISTRY[fc_type](self.d_model, self.d_model + 2 * self.kv_n_heads * self.head_dim, **fc_kwargs)
246
- fuse_splits = [i * self.head_dim for i in range(1, self.n_heads + 2 * self.kv_n_heads)]
 
 
 
 
 
 
247
  self.Wqkv._fused = (0, fuse_splits)
248
- if self.qk_ln:
249
  norm_class = NORM_CLASS_REGISTRY[norm_type.lower()]
250
- self.q_ln = norm_class(self.d_model, device=device)
251
- self.k_ln = norm_class(self.kv_n_heads * self.head_dim, device=device)
252
- if self.attn_impl == 'flash':
 
 
 
253
  self.attn_fn = flash_attn_fn
254
- elif self.attn_impl == 'triton':
255
  self.attn_fn = triton_flash_attn_fn
256
- elif self.attn_impl == 'torch':
257
  self.attn_fn = scaled_multihead_dot_product_attention
258
  else:
259
- raise ValueError(f'attn_impl={attn_impl!r} is an invalid setting.')
260
- self.out_proj = FC_CLASS_REGISTRY[fc_type](self.d_model, self.d_model, **fc_kwargs)
 
 
261
  self.out_proj._is_residual = True
262
 
263
- def forward(self, x: torch.Tensor, past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]]=None, attn_bias: Optional[torch.Tensor]=None, attention_mask: Optional[torch.Tensor]=None, is_causal: bool=True, needs_weights: bool=False) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor, torch.Tensor]]]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
264
  qkv = self.Wqkv(x)
265
  if self.clip_qkv:
266
  qkv = qkv.clamp(min=-self.clip_qkv, max=self.clip_qkv)
267
- (query, key, value) = qkv.split([self.d_model, self.kv_n_heads * self.head_dim, self.kv_n_heads * self.head_dim], dim=2)
 
 
 
 
 
 
 
268
  key_padding_mask = attention_mask
269
- if self.qk_ln:
 
 
 
 
 
270
  dtype = query.dtype
271
- query = self.q_ln(query).to(dtype)
272
- key = self.k_ln(key).to(dtype)
273
- (context, attn_weights, past_key_value) = self.attn_fn(query, key, value, self.n_heads, self.kv_n_heads, past_key_value=past_key_value, softmax_scale=self.softmax_scale, attn_bias=attn_bias, key_padding_mask=key_padding_mask, is_causal=is_causal, dropout_p=self.attn_dropout_p, training=self.training, needs_weights=needs_weights)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
274
  return (self.out_proj(context), attn_weights, past_key_value)
275
 
 
276
  class MultiheadAttention(GroupedQueryAttention):
277
  """Multi-head self attention.
278
 
@@ -280,8 +563,39 @@ class MultiheadAttention(GroupedQueryAttention):
280
  additive bias.
281
  """
282
 
283
- def __init__(self, d_model: int, n_heads: int, attn_impl: str='triton', clip_qkv: Optional[float]=None, qk_ln: bool=False, softmax_scale: Optional[float]=None, attn_pdrop: float=0.0, norm_type: str='low_precision_layernorm', fc_type: str='torch', device: Optional[str]=None, bias: bool=True):
284
- super().__init__(d_model=d_model, n_heads=n_heads, kv_n_heads=n_heads, attn_impl=attn_impl, clip_qkv=clip_qkv, qk_ln=qk_ln, softmax_scale=softmax_scale, attn_pdrop=attn_pdrop, norm_type=norm_type, fc_type=fc_type, device=device, bias=bias)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
285
 
286
  class MultiQueryAttention(GroupedQueryAttention):
287
  """Multi-Query self attention.
@@ -290,13 +604,52 @@ class MultiQueryAttention(GroupedQueryAttention):
290
  additive bias.
291
  """
292
 
293
- def __init__(self, d_model: int, n_heads: int, attn_impl: str='triton', clip_qkv: Optional[float]=None, qk_ln: bool=False, softmax_scale: Optional[float]=None, attn_pdrop: float=0.0, norm_type: str='low_precision_layernorm', fc_type: str='torch', device: Optional[str]=None, bias: bool=True):
294
- super().__init__(d_model=d_model, n_heads=n_heads, kv_n_heads=1, attn_impl=attn_impl, clip_qkv=clip_qkv, qk_ln=qk_ln, softmax_scale=softmax_scale, attn_pdrop=attn_pdrop, norm_type=norm_type, fc_type=fc_type, device=device, bias=bias)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
295
 
296
- def attn_bias_shape(attn_impl: str, n_heads: int, seq_len: int, alibi: bool, prefix_lm: bool, causal: bool, use_sequence_id: bool) -> Optional[Tuple[int, int, int, int]]:
297
- if attn_impl == 'flash':
 
 
 
 
 
 
 
 
 
298
  return None
299
- elif attn_impl in ['torch', 'triton']:
300
  if alibi:
301
  if (prefix_lm or not causal) or use_sequence_id:
302
  return (1, n_heads, seq_len, seq_len)
@@ -305,34 +658,78 @@ def attn_bias_shape(attn_impl: str, n_heads: int, seq_len: int, alibi: bool, pre
305
  return (1, 1, seq_len, seq_len)
306
  return None
307
  else:
308
- raise ValueError(f'attn_impl={attn_impl!r} is an invalid setting.')
 
309
 
310
- def build_attn_bias(attn_impl: str, attn_bias: torch.Tensor, n_heads: int, seq_len: int, causal: bool=False, alibi: bool=False, alibi_bias_max: int=8) -> Optional[torch.Tensor]:
311
- if attn_impl == 'flash':
 
 
 
 
 
 
 
 
312
  return None
313
- elif attn_impl in ['torch', 'triton']:
314
  if alibi:
315
  (device, dtype) = (attn_bias.device, attn_bias.dtype)
316
- attn_bias = attn_bias.add(build_alibi_bias(n_heads, seq_len, full=not causal, alibi_bias_max=alibi_bias_max, device=device, dtype=dtype))
 
 
 
 
 
 
 
 
 
317
  return attn_bias
318
  else:
319
- raise ValueError(f'attn_impl={attn_impl!r} is an invalid setting.')
320
 
321
- def gen_slopes(n_heads: int, alibi_bias_max: int=8, device: Optional[torch.device]=None) -> torch.Tensor:
 
 
 
 
 
 
322
  _n_heads = 2 ** math.ceil(math.log2(n_heads))
323
  m = torch.arange(1, _n_heads + 1, dtype=torch.float32, device=device)
324
  m = m.mul(alibi_bias_max / _n_heads)
325
  slopes = 1.0 / torch.pow(2, m)
326
  if _n_heads != n_heads:
327
  slopes = torch.concat([slopes[1::2], slopes[::2]])[:n_heads]
 
 
328
  return slopes.view(1, n_heads, 1, 1)
329
 
330
- def build_alibi_bias(n_heads: int, seq_len: int, full: bool=False, alibi_bias_max: int=8, device: Optional[torch.device]=None, dtype: Optional[torch.dtype]=None) -> torch.Tensor:
331
- alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.int32, device=device).view(1, 1, 1, seq_len)
 
 
 
 
 
 
 
 
 
 
332
  if full:
333
- alibi_bias = alibi_bias - torch.arange(1 - seq_len, 1, dtype=torch.int32, device=device).view(1, 1, seq_len, 1)
 
 
334
  alibi_bias = alibi_bias.abs().mul(-1)
335
  slopes = gen_slopes(n_heads, alibi_bias_max, device=device)
336
  alibi_bias = alibi_bias * slopes
337
  return alibi_bias.to(dtype=dtype)
338
- ATTN_CLASS_REGISTRY = {'multihead_attention': MultiheadAttention, 'multiquery_attention': MultiQueryAttention, 'grouped_query_attention': GroupedQueryAttention}
 
 
 
 
 
 
 
1
  """Attention layers."""
2
+
3
  import math
4
  import warnings
5
+ from typing import Any, Optional
6
  import torch
7
  import torch.nn as nn
8
+ import transformers
9
  from einops import rearrange
10
  from packaging import version
11
  from torch import nn
12
  from .fc import FC_CLASS_REGISTRY
13
  from .norm import NORM_CLASS_REGISTRY
14
 
15
+
16
+ def is_flash_v2_installed(v2_version: str = "2.0.0"):
17
+ assert version.parse(v2_version) >= version.parse("2.0.0")
18
  try:
19
  import flash_attn as flash_attn
20
  except:
21
  return False
22
+ return version.parse(flash_attn.__version__) >= version.parse(v2_version)
23
+
24
 
25
  def is_flash_v1_installed():
26
  try:
27
  import flash_attn as flash_attn
28
  except:
29
  return False
30
+ return version.parse(flash_attn.__version__) < version.parse("2.0.0")
31
+
32
+
33
+ def is_transformers_version_gte(hf_version: str) -> bool:
34
+ return version.parse(transformers.__version__) >= version.parse(hf_version)
35
+
36
+
37
+ def check_alibi_support(attention_impl: str) -> bool:
38
+ return attention_impl != "flash" or is_flash_v2_installed(v2_version="v2.4.2")
39
+
40
 
41
+ if is_flash_v1_installed():
42
+ import transformers
43
+
44
+ transformers.utils.is_flash_attn_available = lambda: False
45
+ from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
46
+
47
+
48
+ def _reset_is_causal(
49
+ num_query_tokens: int, num_key_tokens: int, original_is_causal: bool
50
+ ) -> bool:
51
  if original_is_causal and num_query_tokens != num_key_tokens:
52
  if num_query_tokens != 1:
53
+ raise NotImplementedError(
54
+ "MPT does not support query and key with different number of tokens, unless number of query tokens is 1."
55
+ )
56
  else:
57
  return False
58
  return original_is_causal
59
 
60
+
61
  def repeat_kv_for_gqa(hidden: torch.Tensor, n_rep: int) -> torch.Tensor:
62
  """Perform repeat of kv heads along a particular dimension.
63
 
 
71
  hidden = hidden[:, :, :, None, :].expand(b, s, kv_n_heads, n_rep, d)
72
  return hidden.reshape(b, s, kv_n_heads * n_rep, d)
73
 
74
+
75
+ def scaled_multihead_dot_product_attention(
76
+ query: torch.Tensor,
77
+ key: torch.Tensor,
78
+ value: torch.Tensor,
79
+ n_heads: int,
80
+ kv_n_heads: int,
81
+ past_key_value: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
82
+ softmax_scale: Optional[float] = None,
83
+ attn_bias: Optional[torch.Tensor] = None,
84
+ key_padding_mask: Optional[torch.Tensor] = None,
85
+ is_causal: bool = False,
86
+ dropout_p: float = 0.0,
87
+ training: bool = False,
88
+ needs_weights: bool = False,
89
+ ) -> tuple[
90
+ torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor, torch.Tensor]]
91
+ ]:
92
+ q = rearrange(query, "b s (h d) -> b h s d", h=n_heads)
93
+ k = rearrange(key, "b s (h d) -> b h d s", h=kv_n_heads)
94
+ v = rearrange(value, "b s (h d) -> b h s d", h=kv_n_heads)
95
  if past_key_value is not None:
96
  if len(past_key_value) != 0:
97
  k = torch.cat([past_key_value[0], k], dim=3)
 
109
  _s_q = max(0, attn_bias.size(2) - s_q)
110
  _s_k = max(0, attn_bias.size(3) - s_k)
111
  attn_bias = attn_bias[:, :, _s_q:, _s_k:]
112
+ if (
113
+ attn_bias.size(-1) != 1
114
+ and attn_bias.size(-1) != s_k
115
+ or (attn_bias.size(-2) != 1 and attn_bias.size(-2) != s_q)
116
+ ):
117
+ raise RuntimeError(
118
+ f"attn_bias (shape: {attn_bias.shape}) is expected to broadcast to shape: {attn_weight.shape}."
119
+ )
120
  attn_weight = attn_weight + attn_bias
121
  min_val = torch.finfo(q.dtype).min
122
  if key_padding_mask is not None:
123
  if attn_bias is not None:
124
+ warnings.warn(
125
+ "Propagating key_padding_mask to the attention module "
126
+ + "and applying it within the attention module can cause "
127
+ + "unnecessary computation/memory usage. Consider integrating "
128
+ + "into attn_bias once and passing that to each attention "
129
+ + "module instead."
130
+ )
131
+ attn_weight = attn_weight.masked_fill(
132
+ ~key_padding_mask.view((b, 1, 1, s_k)), min_val
133
+ )
134
  if is_causal and (not q.size(2) == 1):
135
  s = max(s_q, s_k)
136
  causal_mask = attn_weight.new_ones(s, s, dtype=torch.float32)
 
141
  attn_weight = attn_weight.masked_fill(causal_mask.view(1, 1, s_q, s_k), min_val)
142
  attn_weight = torch.softmax(attn_weight, dim=-1)
143
  if dropout_p:
144
+ attn_weight = torch.nn.functional.dropout(
145
+ attn_weight, p=dropout_p, training=training, inplace=True
146
+ )
147
  out = attn_weight.to(v.dtype).matmul(v)
148
+ out = rearrange(out, "b h s d -> b s (h d)")
149
  if needs_weights:
150
  return (out, attn_weight, past_key_value)
151
  return (out, None, past_key_value)
152
 
153
+
154
+ def check_valid_inputs(
155
+ *tensors: torch.Tensor, valid_dtypes: Optional[list[torch.dtype]] = None
156
+ ):
157
  if valid_dtypes is None:
158
  valid_dtypes = [torch.float16, torch.bfloat16]
159
  for tensor in tensors:
160
  if tensor.dtype not in valid_dtypes:
161
+ raise TypeError(
162
+ f"tensor.dtype={tensor.dtype!r} must be in valid_dtypes={valid_dtypes!r}."
163
+ )
164
  if not tensor.is_cuda:
165
+ raise TypeError(
166
+ f"Inputs must be cuda tensors (tensor.is_cuda={tensor.is_cuda!r})."
167
+ )
168
 
169
+
170
+ def flash_attn_fn(
171
+ query: torch.Tensor,
172
+ key: torch.Tensor,
173
+ value: torch.Tensor,
174
+ n_heads: int,
175
+ kv_n_heads: int,
176
+ past_key_value: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
177
+ softmax_scale: Optional[float] = None,
178
+ attn_bias: Optional[torch.Tensor] = None,
179
+ key_padding_mask: Optional[torch.Tensor] = None,
180
+ is_causal: bool = False,
181
+ dropout_p: float = 0.0,
182
+ training: bool = False,
183
+ needs_weights: bool = False,
184
+ multiquery: bool = False,
185
+ should_repeat_kv_for_gqa: Optional[bool] = True,
186
+ sliding_window_size: int = -1,
187
+ alibi_slopes: Optional[torch.Tensor] = None,
188
+ flash_attn_padding_info: Optional[dict[str, torch.Tensor]] = None,
189
+ ) -> tuple[
190
+ torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor, torch.Tensor]]
191
+ ]:
192
+ if key_padding_mask is not None:
193
+ raise ValueError("key_padding_mask should be None for flash attn.")
194
+ del key_padding_mask
195
+ if flash_attn_padding_info is None:
196
+ raise ValueError("flash_attn_padding_info is required for flash attn.")
197
  try:
198
  from flash_attn import bert_padding, flash_attn_interface
199
  except:
200
+ raise RuntimeError("Please install flash-attn==1.0.9 or flash-attn==2.3.6")
201
  check_valid_inputs(query, key, value)
 
 
 
 
 
 
202
  if past_key_value is not None:
203
  if len(past_key_value) != 0:
204
  key = torch.cat([past_key_value[0], key], dim=1)
205
  value = torch.cat([past_key_value[1], value], dim=1)
206
  past_key_value = (key, value)
207
  if attn_bias is not None:
208
+ raise NotImplementedError(f"attn_bias not implemented for flash attn.")
 
 
 
 
209
  (batch_size, seqlen) = query.shape[:2]
210
+ indices_q = flash_attn_padding_info["indices_q"]
211
+ indices_k = flash_attn_padding_info["indices_k"]
212
+ indices_v = flash_attn_padding_info["indices_v"]
213
+ cu_seqlens_q = flash_attn_padding_info["cu_seqlens_q"]
214
+ cu_seqlens_k = flash_attn_padding_info["cu_seqlens_k"]
215
+ max_seqlen_q = flash_attn_padding_info["max_seqlen_q"]
216
+ max_seqlen_k = flash_attn_padding_info["max_seqlen_k"]
217
+ query_unpad = bert_padding.index_first_axis(
218
+ rearrange(query, "b s ... -> (b s) ..."), indices_q
219
+ )
220
+ query_unpad = rearrange(query_unpad, "nnz (h d) -> nnz h d", h=n_heads)
221
+ key_unpad = bert_padding.index_first_axis(
222
+ rearrange(key, "b s ... -> (b s) ..."), indices_k
223
+ )
224
+ key_unpad = rearrange(key_unpad, "nnz (h d) -> nnz h d", h=kv_n_heads)
225
+ value_unpad = bert_padding.index_first_axis(
226
+ rearrange(value, "b s ... -> (b s) ..."), indices_v
227
+ )
228
+ value_unpad = rearrange(value_unpad, "nnz (h d) -> nnz h d", h=kv_n_heads)
229
+ if (
230
+ kv_n_heads < n_heads
231
+ and (not is_flash_v2_installed())
232
+ and (not should_repeat_kv_for_gqa)
233
+ ):
234
+ raise ValueError(
235
+ "For Grouped Query Attention or Multi Query Attention, should_repeat_kv_for_gqa should be set to True if not using Flash Attention v2."
236
+ )
237
+ if should_repeat_kv_for_gqa:
238
+ if kv_n_heads == 1:
239
+ key_unpad = key_unpad.expand(key_unpad.size(0), n_heads, key_unpad.size(-1))
240
+ value_unpad = value_unpad.expand(
241
+ value_unpad.size(0), n_heads, value_unpad.size(-1)
242
+ )
243
+ elif kv_n_heads < n_heads:
244
+ key_unpad = repeat_kv_for_gqa(
245
+ key_unpad.view(1, key_unpad.size(0), kv_n_heads, -1),
246
+ n_heads // kv_n_heads,
247
+ ).view(key_unpad.size(0), n_heads, -1)
248
+ value_unpad = repeat_kv_for_gqa(
249
+ value_unpad.view(1, value_unpad.size(0), kv_n_heads, -1),
250
+ n_heads // kv_n_heads,
251
+ ).view(value_unpad.size(0), n_heads, -1)
252
  dropout_p = dropout_p if training else 0.0
253
  reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
254
  if is_flash_v1_installed():
255
+ output_unpad = flash_attn_interface.flash_attn_unpadded_func(
256
+ q=query_unpad,
257
+ k=key_unpad,
258
+ v=value_unpad,
259
+ cu_seqlens_q=cu_seqlens_q,
260
+ cu_seqlens_k=cu_seqlens_k,
261
+ max_seqlen_q=max_seqlen_q,
262
+ max_seqlen_k=max_seqlen_k,
263
+ dropout_p=dropout_p,
264
+ softmax_scale=softmax_scale,
265
+ causal=reset_is_causal,
266
+ return_attn_probs=needs_weights,
267
+ )
268
  elif is_flash_v2_installed():
269
+ alibi_kwargs = {}
270
+ if check_alibi_support("flash"):
271
+ alibi_kwargs = {"alibi_slopes": alibi_slopes}
272
+ elif alibi_slopes is not None:
273
+ raise ValueError("alibi_slopes is only supported for flash-attn>=2.4.2")
274
+ output_unpad = flash_attn_interface.flash_attn_varlen_func(
275
+ q=query_unpad,
276
+ k=key_unpad,
277
+ v=value_unpad,
278
+ cu_seqlens_q=cu_seqlens_q,
279
+ cu_seqlens_k=cu_seqlens_k,
280
+ max_seqlen_q=max_seqlen_q,
281
+ max_seqlen_k=max_seqlen_k,
282
+ dropout_p=dropout_p,
283
+ softmax_scale=softmax_scale,
284
+ causal=reset_is_causal,
285
+ return_attn_probs=needs_weights,
286
+ window_size=(sliding_window_size, sliding_window_size),
287
+ **alibi_kwargs,
288
+ )
289
  else:
290
+ raise RuntimeError("flash-attn==1.0.9 or flash-attn==2.4.2 is required.")
291
+ output = bert_padding.pad_input(
292
+ rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices_q, batch_size, seqlen
293
+ )
294
  return (output, None, past_key_value)
295
 
296
+
297
+ def triton_flash_attn_fn(
298
+ query: torch.Tensor,
299
+ key: torch.Tensor,
300
+ value: torch.Tensor,
301
+ n_heads: int,
302
+ kv_n_heads: int,
303
+ past_key_value: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
304
+ softmax_scale: Optional[float] = None,
305
+ attn_bias: Optional[torch.Tensor] = None,
306
+ key_padding_mask: Optional[torch.Tensor] = None,
307
+ is_causal: bool = False,
308
+ dropout_p: float = 0.0,
309
+ training: bool = False,
310
+ needs_weights: bool = False,
311
+ ) -> tuple[
312
+ torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor, torch.Tensor]]
313
+ ]:
314
  try:
315
  from .flash_attn_triton import flash_attn_func
316
  except:
317
  _installed = False
318
+ if version.parse(torch.__version__) < version.parse("2.0.0"):
319
  _installed = True
320
  try:
321
  from flash_attn.flash_attn_triton import flash_attn_func
322
  except:
323
  _installed = False
324
  if not _installed:
325
+ raise RuntimeError(
326
+ "Requirements for `attn_impl: triton` not installed. Either (1) have a CUDA-compatible GPU "
327
+ + "and `pip install .[gpu]` if installing from llm-foundry source or "
328
+ + "`pip install triton-pre-mlir@git+https://github.com/vchiley/triton.git@triton_pre_mlir#subdirectory=python` "
329
+ + "if installing from pypi, or (2) use torch attn model.attn_config.attn_impl=torch (torch attn_impl will be slow). "
330
+ + "Note: (1) requires you have CMake and PyTorch already installed."
331
+ )
332
  check_valid_inputs(query, key, value)
 
 
 
 
 
 
333
  if past_key_value is not None:
334
  if len(past_key_value) != 0:
335
  key = torch.cat([past_key_value[0], key], dim=1)
 
340
  _s_k = max(0, attn_bias.size(3) - key.size(1))
341
  attn_bias = attn_bias[:, :, _s_q:, _s_k:]
342
  if dropout_p:
343
+ raise NotImplementedError(f"Dropout not implemented for attn_impl: triton.")
344
  dropout_p = dropout_p if training else 0.0
345
  if needs_weights:
346
+ raise NotImplementedError(f"attn_impl: triton cannot return attn weights.")
347
  if key_padding_mask is not None:
348
+ warnings.warn(
349
+ "Propagating key_padding_mask to the attention module "
350
+ + "and applying it within the attention module can cause "
351
+ + "unnecessary computation/memory usage. Consider integrating "
352
+ + "into attn_bias once and passing that to each attention "
353
+ + "module instead."
354
+ )
355
  (b_size, s_k) = key_padding_mask.shape[:2]
356
  if attn_bias is None:
357
  attn_bias = query.new_zeros(b_size, 1, 1, s_k)
358
+ attn_bias = attn_bias.masked_fill(
359
+ ~key_padding_mask.view((b_size, 1, 1, s_k)), torch.finfo(query.dtype).min
360
+ )
361
+ query = rearrange(query, "b s (h d) -> b s h d", h=n_heads)
362
+ key = rearrange(key, "b s (h d) -> b s h d", h=kv_n_heads)
363
+ value = rearrange(value, "b s (h d) -> b s h d", h=kv_n_heads)
364
  if kv_n_heads == 1:
365
  key = key.repeat(1, 1, n_heads, 1)
366
  value = value.repeat(1, 1, n_heads, 1)
 
368
  key = repeat_kv_for_gqa(key, n_heads // kv_n_heads)
369
  value = repeat_kv_for_gqa(value, n_heads // kv_n_heads)
370
  reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
371
+ attn_output = flash_attn_func(
372
+ query, key, value, attn_bias, reset_is_causal, softmax_scale
373
+ )
374
  output = attn_output.view(*attn_output.shape[:2], -1)
375
  return (output, None, past_key_value)
376
 
377
+
378
  class GroupedQueryAttention(nn.Module):
379
  """Grouped Query Attention (GQA) is a generalization of Multi-head (MHA).
380
 
 
385
  implementation enables user to also use additive bias.
386
  """
387
 
388
+ def __init__(
389
+ self,
390
+ d_model: int,
391
+ n_heads: int,
392
+ kv_n_heads: int,
393
+ attn_impl: str = "triton",
394
+ clip_qkv: Optional[float] = None,
395
+ qk_ln: bool = False,
396
+ qk_gn: bool = False,
397
+ softmax_scale: Optional[float] = None,
398
+ attn_pdrop: float = 0.0,
399
+ norm_type: str = "low_precision_layernorm",
400
+ fc_type: str = "torch",
401
+ device: Optional[str] = None,
402
+ bias: bool = True,
403
+ sliding_window_size: int = -1,
404
+ ):
405
  super().__init__()
406
  self.attn_impl = attn_impl
407
  self.clip_qkv = clip_qkv
408
  self.qk_ln = qk_ln
409
+ self.qk_gn = qk_gn
410
  self.d_model = d_model
411
  self.n_heads = n_heads
412
  self.kv_n_heads = kv_n_heads
413
+ self.sliding_window_size = sliding_window_size
414
  self.head_dim = d_model // n_heads
415
  if self.kv_n_heads <= 0:
416
+ raise ValueError("kv_n_heads should be greater than zero.")
417
  if self.kv_n_heads > self.n_heads:
418
+ raise ValueError(
419
+ "The number of KV heads should be less than or equal to Q heads."
420
+ )
421
  if self.n_heads % self.kv_n_heads != 0:
422
+ raise ValueError(
423
+ "Each Q head should get the same number of KV heads, so n_heads must be divisible by kv_n_heads."
424
+ )
425
+ if qk_ln and qk_gn:
426
+ raise ValueError("Only one of qk_ln and qk_gn can be set to True.")
427
  self.softmax_scale = softmax_scale
428
  if self.softmax_scale is None:
429
  self.softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads)
430
  self.attn_dropout_p = attn_pdrop
431
+ fc_kwargs: dict[str, Any] = {"bias": bias}
432
+ if fc_type != "te":
433
+ fc_kwargs["device"] = device
434
+ self.Wqkv = FC_CLASS_REGISTRY[fc_type](
435
+ self.d_model,
436
+ self.d_model + 2 * self.kv_n_heads * self.head_dim,
437
+ **fc_kwargs,
438
+ )
439
+ fuse_splits = [
440
+ i * self.head_dim for i in range(1, self.n_heads + 2 * self.kv_n_heads)
441
+ ]
442
  self.Wqkv._fused = (0, fuse_splits)
443
+ if self.qk_ln or self.qk_gn:
444
  norm_class = NORM_CLASS_REGISTRY[norm_type.lower()]
445
+ norm_size = self.head_dim if qk_gn else d_model
446
+ self.q_ln = norm_class(norm_size, device=device)
447
+ if qk_ln:
448
+ norm_size = self.head_dim * kv_n_heads
449
+ self.k_ln = norm_class(norm_size, device=device)
450
+ if self.attn_impl == "flash":
451
  self.attn_fn = flash_attn_fn
452
+ elif self.attn_impl == "triton":
453
  self.attn_fn = triton_flash_attn_fn
454
+ elif self.attn_impl == "torch":
455
  self.attn_fn = scaled_multihead_dot_product_attention
456
  else:
457
+ raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.")
458
+ self.out_proj = FC_CLASS_REGISTRY[fc_type](
459
+ self.d_model, self.d_model, **fc_kwargs
460
+ )
461
  self.out_proj._is_residual = True
462
 
463
+ def forward(
464
+ self,
465
+ x: torch.Tensor,
466
+ past_key_value: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
467
+ attn_bias: Optional[torch.Tensor] = None,
468
+ attention_mask: Optional[torch.Tensor] = None,
469
+ rotary_emb_w_meta_info: Optional[dict] = None,
470
+ is_causal: bool = True,
471
+ needs_weights: bool = False,
472
+ alibi_slopes: Optional[torch.Tensor] = None,
473
+ flash_attn_padding_info: Optional[dict[str, torch.Tensor]] = None,
474
+ ) -> tuple[
475
+ torch.Tensor,
476
+ Optional[torch.Tensor],
477
+ Optional[tuple[torch.Tensor, torch.Tensor]],
478
+ ]:
479
  qkv = self.Wqkv(x)
480
  if self.clip_qkv:
481
  qkv = qkv.clamp(min=-self.clip_qkv, max=self.clip_qkv)
482
+ (query, key, value) = qkv.split(
483
+ [
484
+ self.d_model,
485
+ self.kv_n_heads * self.head_dim,
486
+ self.kv_n_heads * self.head_dim,
487
+ ],
488
+ dim=2,
489
+ )
490
  key_padding_mask = attention_mask
491
+ if self.qk_ln or self.qk_gn:
492
+ (q_shape, k_shape) = (query.shape, key.shape)
493
+ if self.qk_gn:
494
+ (b, s) = query.shape[:2]
495
+ query = query.view(b, s, self.n_heads, -1)
496
+ key = key.view(b, s, self.kv_n_heads, -1)
497
  dtype = query.dtype
498
+ query = self.q_ln(query).to(dtype).view(q_shape)
499
+ key = self.k_ln(key).to(dtype).view(k_shape)
500
+ if rotary_emb_w_meta_info is not None:
501
+ rotary_emb = rotary_emb_w_meta_info["rotary_emb"]
502
+ seq_len = rotary_emb_w_meta_info["seq_len"]
503
+ offset_info = rotary_emb_w_meta_info["offset_info"]
504
+ (bsz, seqlen) = query.shape[:2]
505
+ query = query.view(bsz, seqlen, -1, self.head_dim)
506
+ key = key.view(bsz, seqlen, -1, self.head_dim)
507
+ if rotary_emb_w_meta_info["impl"] == "dail":
508
+ value = value.view(bsz, seqlen, -1, self.head_dim)
509
+ kv = torch.stack([key, value], dim=2)
510
+ (query, kv) = rotary_emb(
511
+ query, kv, seqlen_offset=offset_info, max_seqlen=seq_len
512
+ )
513
+ [key, value] = torch.unbind(kv, dim=2)
514
+ value = value.view(bsz, seqlen, self.kv_n_heads * self.head_dim)
515
+ elif rotary_emb_w_meta_info["impl"] == "hf":
516
+ (cos, sin) = rotary_emb(value, seq_len)
517
+ if is_transformers_version_gte("4.36"):
518
+ (query, key) = apply_rotary_pos_emb(
519
+ query, key, cos, sin, offset_info, unsqueeze_dim=2
520
+ )
521
+ else:
522
+ query = query.transpose(1, 2)
523
+ key = key.transpose(1, 2)
524
+ (query, key) = apply_rotary_pos_emb(
525
+ query, key, cos, sin, offset_info
526
+ )
527
+ query = query.transpose(1, 2)
528
+ key = key.transpose(1, 2)
529
+ query = query.view(bsz, seqlen, self.d_model)
530
+ key = key.view(bsz, seqlen, self.kv_n_heads * self.head_dim)
531
+ extra_attn_kwargs = {}
532
+ if self.attn_impl == "flash":
533
+ key_padding_mask = None
534
+ extra_attn_kwargs = {
535
+ "should_repeat_kv_for_gqa": not is_flash_v2_installed(),
536
+ "sliding_window_size": self.sliding_window_size,
537
+ "alibi_slopes": alibi_slopes,
538
+ "flash_attn_padding_info": flash_attn_padding_info,
539
+ }
540
+ (context, attn_weights, past_key_value) = self.attn_fn(
541
+ query,
542
+ key,
543
+ value,
544
+ self.n_heads,
545
+ self.kv_n_heads,
546
+ past_key_value=past_key_value,
547
+ softmax_scale=self.softmax_scale,
548
+ attn_bias=attn_bias,
549
+ key_padding_mask=key_padding_mask,
550
+ is_causal=is_causal,
551
+ dropout_p=self.attn_dropout_p,
552
+ training=self.training,
553
+ needs_weights=needs_weights,
554
+ **extra_attn_kwargs,
555
+ )
556
  return (self.out_proj(context), attn_weights, past_key_value)
557
 
558
+
559
  class MultiheadAttention(GroupedQueryAttention):
560
  """Multi-head self attention.
561
 
 
563
  additive bias.
564
  """
565
 
566
+ def __init__(
567
+ self,
568
+ d_model: int,
569
+ n_heads: int,
570
+ attn_impl: str = "triton",
571
+ clip_qkv: Optional[float] = None,
572
+ qk_ln: bool = False,
573
+ qk_gn: bool = False,
574
+ softmax_scale: Optional[float] = None,
575
+ attn_pdrop: float = 0.0,
576
+ norm_type: str = "low_precision_layernorm",
577
+ fc_type: str = "torch",
578
+ device: Optional[str] = None,
579
+ bias: bool = True,
580
+ sliding_window_size: int = -1,
581
+ ):
582
+ super().__init__(
583
+ d_model=d_model,
584
+ n_heads=n_heads,
585
+ kv_n_heads=n_heads,
586
+ attn_impl=attn_impl,
587
+ clip_qkv=clip_qkv,
588
+ qk_ln=qk_ln,
589
+ qk_gn=qk_gn,
590
+ softmax_scale=softmax_scale,
591
+ attn_pdrop=attn_pdrop,
592
+ norm_type=norm_type,
593
+ fc_type=fc_type,
594
+ device=device,
595
+ bias=bias,
596
+ sliding_window_size=sliding_window_size,
597
+ )
598
+
599
 
600
  class MultiQueryAttention(GroupedQueryAttention):
601
  """Multi-Query self attention.
 
603
  Using torch or triton attention implementation enables user to also use
604
  additive bias.
605
  """
606
 
607
+ def __init__(
608
+ self,
609
+ d_model: int,
610
+ n_heads: int,
611
+ attn_impl: str = "triton",
612
+ clip_qkv: Optional[float] = None,
613
+ qk_ln: bool = False,
614
+ qk_gn: bool = False,
615
+ softmax_scale: Optional[float] = None,
616
+ attn_pdrop: float = 0.0,
617
+ norm_type: str = "low_precision_layernorm",
618
+ fc_type: str = "torch",
619
+ device: Optional[str] = None,
620
+ bias: bool = True,
621
+ sliding_window_size: int = -1,
622
+ ):
623
+ super().__init__(
624
+ d_model=d_model,
625
+ n_heads=n_heads,
626
+ kv_n_heads=1,
627
+ attn_impl=attn_impl,
628
+ clip_qkv=clip_qkv,
629
+ qk_ln=qk_ln,
630
+ qk_gn=qk_gn,
631
+ softmax_scale=softmax_scale,
632
+ attn_pdrop=attn_pdrop,
633
+ norm_type=norm_type,
634
+ fc_type=fc_type,
635
+ device=device,
636
+ bias=bias,
637
+ sliding_window_size=sliding_window_size,
638
+ )
639
 
640
+
641
+ def attn_bias_shape(
642
+ attn_impl: str,
643
+ n_heads: int,
644
+ seq_len: int,
645
+ alibi: bool,
646
+ prefix_lm: bool,
647
+ causal: bool,
648
+ use_sequence_id: bool,
649
+ ) -> Optional[tuple[int, int, int, int]]:
650
+ if attn_impl == "flash":
651
  return None
652
+ elif attn_impl in ["torch", "triton"]:
653
  if alibi:
654
  if (prefix_lm or not causal) or use_sequence_id:
655
  return (1, n_heads, seq_len, seq_len)
656
  return (1, n_heads, 1, seq_len)
657
  elif prefix_lm or use_sequence_id:
658
  return (1, 1, seq_len, seq_len)
659
  return None
660
  else:
661
+ raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.")
662
+
663
 
664
+ def build_attn_bias(
665
+ attn_impl: str,
666
+ attn_bias: torch.Tensor,
667
+ n_heads: int,
668
+ seq_len: int,
669
+ causal: bool = False,
670
+ alibi: bool = False,
671
+ alibi_bias_max: int = 8,
672
+ ) -> Optional[torch.Tensor]:
673
+ if attn_impl == "flash":
674
  return None
675
+ elif attn_impl in ["torch", "triton"]:
676
  if alibi:
677
  (device, dtype) = (attn_bias.device, attn_bias.dtype)
678
+ attn_bias = attn_bias.add(
679
+ build_alibi_bias(
680
+ n_heads,
681
+ seq_len,
682
+ full=not causal,
683
+ alibi_bias_max=alibi_bias_max,
684
+ device=device,
685
+ dtype=dtype,
686
+ )
687
+ )
688
  return attn_bias
689
  else:
690
+ raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.")
691
 
692
+
693
+ def gen_slopes(
694
+ n_heads: int,
695
+ alibi_bias_max: int = 8,
696
+ device: Optional[torch.device] = None,
697
+ return_1d: bool = False,
698
+ ) -> torch.Tensor:
699
  _n_heads = 2 ** math.ceil(math.log2(n_heads))
700
  m = torch.arange(1, _n_heads + 1, dtype=torch.float32, device=device)
701
  m = m.mul(alibi_bias_max / _n_heads)
702
  slopes = 1.0 / torch.pow(2, m)
703
  if _n_heads != n_heads:
704
  slopes = torch.concat([slopes[1::2], slopes[::2]])[:n_heads]
705
+ if return_1d:
706
+ return slopes
707
  return slopes.view(1, n_heads, 1, 1)
708
 
709
+
710
+ def build_alibi_bias(
711
+ n_heads: int,
712
+ seq_len: int,
713
+ full: bool = False,
714
+ alibi_bias_max: int = 8,
715
+ device: Optional[torch.device] = None,
716
+ dtype: Optional[torch.dtype] = None,
717
+ ) -> torch.Tensor:
718
+ alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.int32, device=device).view(
719
+ 1, 1, 1, seq_len
720
+ )
721
  if full:
722
+ alibi_bias = alibi_bias - torch.arange(
723
+ 1 - seq_len, 1, dtype=torch.int32, device=device
724
+ ).view(1, 1, seq_len, 1)
725
  alibi_bias = alibi_bias.abs().mul(-1)
726
  slopes = gen_slopes(n_heads, alibi_bias_max, device=device)
727
  alibi_bias = alibi_bias * slopes
728
  return alibi_bias.to(dtype=dtype)
729
+
730
+
731
+ ATTN_CLASS_REGISTRY = {
732
+ "multihead_attention": MultiheadAttention,
733
+ "multiquery_attention": MultiQueryAttention,
734
+ "grouped_query_attention": GroupedQueryAttention,
735
+ }
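
The helpers above decide whether a dense attention-bias tensor is needed at all and, for ALiBi, fill it in; with `attn_impl="flash"` no tensor is built and the slopes are passed into the kernel via `alibi_slopes` instead. A minimal sketch of how they fit together (it assumes `attention.py` from this repository is importable, e.g. from a local clone placed on the Python path):

```python
# Sketch only: exercises attn_bias_shape / build_attn_bias for the torch implementation
# with ALiBi enabled. Assumes this repository's attention.py is importable.
import torch
from attention import attn_bias_shape, build_attn_bias

n_heads, seq_len = 4, 8

shape = attn_bias_shape(
    "torch", n_heads, seq_len,
    alibi=True, prefix_lm=False, causal=False, use_sequence_id=False,
)
print(shape)  # (1, 4, 8, 8): non-causal ALiBi needs the full per-head bias matrix

attn_bias = build_attn_bias(
    "torch", torch.zeros(shape), n_heads, seq_len, causal=False, alibi=True
)
print(attn_bias[0, 0])  # linearly decaying penalty with distance, one slope per head

# With attn_impl="flash" both helpers return None; the bias never materialises.
print(attn_bias_shape(
    "flash", n_heads, seq_len,
    alibi=True, prefix_lm=False, causal=True, use_sequence_id=False,
))
```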
blocks.py CHANGED
@@ -1,4 +1,5 @@
1
  """GPT Blocks used for the GPT Model."""
 
2
  from typing import Any, Dict, Optional, Tuple
3
  import torch
4
  import torch.nn as nn
@@ -6,8 +7,37 @@ from .attention import ATTN_CLASS_REGISTRY
6
  from .ffn import FFN_CLASS_REGISTRY, build_ffn
7
  from .norm import NORM_CLASS_REGISTRY
8
 
9
 
10
  class MPTBlock(nn.Module):
 
11
  def __init__(
12
  self,
13
  d_model: int,
@@ -20,21 +50,11 @@ class MPTBlock(nn.Module):
20
  fc_type: str = "torch",
21
  device: Optional[str] = None,
22
  no_bias: bool = False,
 
23
  **kwargs: Any
24
  ):
25
  if attn_config is None:
26
- attn_config = {
27
- "attn_type": "multihead_attention",
28
- "attn_pdrop": 0.0,
29
- "attn_impl": "triton",
30
- "qk_ln": False,
31
- "clip_qkv": None,
32
- "softmax_scale": None,
33
- "prefix_lm": False,
34
- "attn_uses_sequence_id": False,
35
- "alibi": False,
36
- "alibi_bias_max": 8,
37
- }
38
  if ffn_config is None:
39
  ffn_config = {"ffn_type": "mptmlp"}
40
  del kwargs
@@ -48,6 +68,11 @@ class MPTBlock(nn.Module):
48
  "alibi",
49
  "attn_uses_sequence_id",
50
  "alibi_bias_max",
51
  }
52
  attn_config_subset_for_attn_class = {
53
  k: v
@@ -75,15 +100,19 @@ class MPTBlock(nn.Module):
75
  )
76
  self.resid_attn_dropout = nn.Dropout(resid_pdrop)
77
  self.resid_ffn_dropout = nn.Dropout(resid_pdrop)
 
78
 
79
  def forward(
80
  self,
81
  x: torch.Tensor,
82
  past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
83
  attn_bias: Optional[torch.Tensor] = None,
 
84
  attention_mask: Optional[torch.ByteTensor] = None,
85
  is_causal: bool = True,
86
  output_attentions: bool = False,
 
 
87
  ) -> Tuple[
88
  torch.Tensor,
89
  Optional[torch.Tensor],
@@ -94,14 +123,25 @@ class MPTBlock(nn.Module):
94
  a,
95
  past_key_value=past_key_value,
96
  attn_bias=attn_bias,
 
97
  attention_mask=attention_mask,
98
  is_causal=is_causal,
99
  needs_weights=output_attentions,
 
 
100
  )
101
  x = x + self.resid_attn_dropout(b)
102
  m = x
103
  if self.norm_2 is not None:
104
  m = self.norm_2(x)
105
  n = self.ffn(m)
 
 
 
106
  x = x + self.resid_ffn_dropout(n)
107
  return (x, attn_weights, past_key_value)
 
1
  """GPT Blocks used for the GPT Model."""
2
+
3
  from typing import Any, Dict, Optional, Tuple
4
  import torch
5
  import torch.nn as nn
 
7
  from .ffn import FFN_CLASS_REGISTRY, build_ffn
8
  from .norm import NORM_CLASS_REGISTRY
9
 
10
+ try:
11
+ from flash_attn.bert_padding import unpad_input, pad_input
12
+ except:
13
+ (unpad_input, pad_input) = (None, None)
14
+ attn_config_defaults: Dict = {
15
+ "attn_type": "multihead_attention",
16
+ "attn_pdrop": 0.0,
17
+ "attn_impl": "flash",
18
+ "qk_ln": True,
19
+ "qk_gn": False,
20
+ "clip_qkv": None,
21
+ "softmax_scale": None,
22
+ "prefix_lm": False,
23
+ "attn_uses_sequence_id": False,
24
+ "sliding_window_size": -1,
25
+ "alibi": False,
26
+ "alibi_bias_max": 8,
27
+ "rope": False,
28
+ "rope_theta": 10000,
29
+ "rope_impl": "dail",
30
+ "rope_dail_config": {
31
+ "type": "original",
32
+ "pos_idx_in_fp32": True,
33
+ "xpos_scale_base": 512,
34
+ },
35
+ "rope_hf_config": {"type": "no_scaling", "factor": 1.0},
36
+ }
37
+
38
 
39
  class MPTBlock(nn.Module):
40
+
41
  def __init__(
42
  self,
43
  d_model: int,
 
50
  fc_type: str = "torch",
51
  device: Optional[str] = None,
52
  no_bias: bool = False,
53
+ use_pad_tok_in_ffn: bool = True,
54
  **kwargs: Any
55
  ):
56
  if attn_config is None:
57
+ attn_config = attn_config_defaults
58
  if ffn_config is None:
59
  ffn_config = {"ffn_type": "mptmlp"}
60
  del kwargs
 
68
  "alibi",
69
  "attn_uses_sequence_id",
70
  "alibi_bias_max",
71
+ "rope",
72
+ "rope_theta",
73
+ "rope_impl",
74
+ "rope_dail_config",
75
+ "rope_hf_config",
76
  }
77
  attn_config_subset_for_attn_class = {
78
  k: v
 
100
  )
101
  self.resid_attn_dropout = nn.Dropout(resid_pdrop)
102
  self.resid_ffn_dropout = nn.Dropout(resid_pdrop)
103
+ self.use_pad_tok_in_ffn = use_pad_tok_in_ffn
104
 
105
  def forward(
106
  self,
107
  x: torch.Tensor,
108
  past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
109
  attn_bias: Optional[torch.Tensor] = None,
110
+ rotary_emb_w_meta_info: Optional[Dict] = None,
111
  attention_mask: Optional[torch.ByteTensor] = None,
112
  is_causal: bool = True,
113
  output_attentions: bool = False,
114
+ alibi_slopes: Optional[torch.Tensor] = None,
115
+ flash_attn_padding_info: Optional[dict[str, torch.Tensor]] = None,
116
  ) -> Tuple[
117
  torch.Tensor,
118
  Optional[torch.Tensor],
 
123
  a,
124
  past_key_value=past_key_value,
125
  attn_bias=attn_bias,
126
+ rotary_emb_w_meta_info=rotary_emb_w_meta_info,
127
  attention_mask=attention_mask,
128
  is_causal=is_causal,
129
  needs_weights=output_attentions,
130
+ alibi_slopes=alibi_slopes,
131
+ flash_attn_padding_info=flash_attn_padding_info,
132
  )
133
  x = x + self.resid_attn_dropout(b)
134
  m = x
135
  if self.norm_2 is not None:
136
  m = self.norm_2(x)
137
+ (batch_size, seq_len) = m.size()[:2]
138
+ indices = None
139
+ if not self.use_pad_tok_in_ffn:
140
+ assert unpad_input is not None
141
+ (m, indices, _, _) = unpad_input(m, attention_mask)
142
  n = self.ffn(m)
143
+ if not self.use_pad_tok_in_ffn:
144
+ assert pad_input is not None
145
+ n = pad_input(n, indices, batch_size, seq_len)
146
  x = x + self.resid_ffn_dropout(n)
147
  return (x, attn_weights, past_key_value)
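
The new `use_pad_tok_in_ffn` switch drops padding rows before the FFN and scatters the results back afterwards via flash-attn's `unpad_input`/`pad_input`. The sketch below emulates that round trip with plain torch so the bookkeeping is visible without a GPU; the `unpad`/`repad` helpers are illustrative stand-ins, not the flash-attn API:

```python
import torch

def unpad(x: torch.Tensor, mask: torch.Tensor):
    """Flatten (batch, seq, d) to (num_real_tokens, d), keeping the flat indices."""
    b, s, d = x.shape
    idx = torch.nonzero(mask.flatten(), as_tuple=False).squeeze(-1)
    return x.reshape(b * s, d)[idx], idx

def repad(x_unpadded: torch.Tensor, idx: torch.Tensor, b: int, s: int) -> torch.Tensor:
    """Scatter the unpadded rows back into a zero-initialised (batch, seq, d) tensor."""
    d = x_unpadded.shape[-1]
    out = torch.zeros(b * s, d, dtype=x_unpadded.dtype)
    out[idx] = x_unpadded
    return out.view(b, s, d)

x = torch.randn(2, 4, 8)
mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]], dtype=torch.bool)

m, idx = unpad(x, mask)   # only the 5 real tokens go through the FFN
n = m * 2.0               # stand-in for self.ffn(m)
y = repad(n, idx, 2, 4)

assert y[mask].allclose(x[mask] * 2.0)
assert (y[~mask] == 0).all()  # pad positions come back as zeros
```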
config.json CHANGED
@@ -12,7 +12,21 @@
12
  "attn_uses_sequence_id": false,
13
  "clip_qkv": null,
14
  "prefix_lm": false,
 
15
  "qk_ln": true,
16
  "softmax_scale": null
17
  },
18
  "auto_map": {
@@ -55,5 +69,6 @@
55
  "torch_dtype": "bfloat16",
56
  "transformers_version": "4.37.2",
57
  "use_cache": false,
 
58
  "vocab_size": 256000
59
  }
 
12
  "attn_uses_sequence_id": false,
13
  "clip_qkv": null,
14
  "prefix_lm": false,
15
+ "qk_gn": false,
16
  "qk_ln": true,
17
+ "rope": false,
18
+ "rope_dail_config": {
19
+ "pos_idx_in_fp32": true,
20
+ "type": "original",
21
+ "xpos_scale_base": 512
22
+ },
23
+ "rope_hf_config": {
24
+ "factor": 1.0,
25
+ "type": "no_scaling"
26
+ },
27
+ "rope_impl": "dail",
28
+ "rope_theta": 10000,
29
+ "sliding_window_size": -1,
30
  "softmax_scale": null
31
  },
32
  "auto_map": {
 
69
  "torch_dtype": "bfloat16",
70
  "transformers_version": "4.37.2",
71
  "use_cache": false,
72
+ "use_pad_tok_in_ffn": true,
73
  "vocab_size": 256000
74
  }
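
The updated `config.json` surfaces the new attention options (`qk_gn`, the `rope*` block, `sliding_window_size`) and the top-level `use_pad_tok_in_ffn` flag. A quick sanity check after downloading the repository (the relative path is a placeholder for wherever the files live locally):

```python
import json
from pathlib import Path

cfg = json.loads(Path("config.json").read_text())  # placeholder path to the downloaded repo
attn = cfg["attn_config"]

print(attn["qk_ln"], attn["qk_gn"])       # True False
print(attn["rope"], attn["rope_impl"])    # False dail
print(attn["sliding_window_size"])        # -1 (disabled)
print(cfg["use_pad_tok_in_ffn"])          # True
```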
configuration_mpt.py CHANGED
@@ -1,22 +1,63 @@
1
  """A HuggingFace-style model configuration."""
 
2
  import warnings
3
  from typing import Any, Dict, Optional, Union
4
  from transformers import PretrainedConfig
5
- attn_config_defaults: Dict = {'attn_type': 'multihead_attention', 'attn_pdrop': 0.0, 'attn_impl': 'triton', 'qk_ln': False, 'clip_qkv': None, 'softmax_scale': None, 'prefix_lm': False, 'attn_uses_sequence_id': False, 'alibi': False, 'alibi_bias_max': 8}
6
- ffn_config_defaults: Dict = {'ffn_type': 'mptmlp'}
7
- init_config_defaults: Dict = {'name': 'kaiming_normal_', 'fan_mode': 'fan_in', 'init_nonlinearity': 'relu', 'init_div_is_residual': True, 'emb_init_std': None, 'emb_init_uniform_lim': None, 'init_std': None, 'init_gain': 0.0}
8
 
9
  class MPTConfig(PretrainedConfig):
10
- model_type = 'mpt'
11
 
12
- def __init__(self, d_model: int=2048, n_heads: int=16, n_layers: int=24, expansion_ratio: int=4, max_seq_len: int=2048, vocab_size: int=50368, resid_pdrop: float=0.0, emb_pdrop: float=0.0, learned_pos_emb: bool=True, attn_config: Dict=attn_config_defaults, ffn_config: Dict=ffn_config_defaults, init_device: str='cpu', logit_scale: Optional[Union[float, str]]=None, no_bias: bool=False, embedding_fraction: float=1.0, norm_type: str='low_precision_layernorm', use_cache: bool=False, init_config: Dict=init_config_defaults, fc_type: str='torch', verbose: Optional[int]=None, **kwargs: Any):
13
  """The MPT configuration class.
14
 
15
  Args:
16
  d_model (int): The size of the embedding dimension of the model.
17
  n_heads (int): The number of attention heads.
18
  n_layers (int): The number of layers in the model.
19
- expansion_ratio (int): The ratio of the up/down scale in the ffn.
20
  max_seq_len (int): The maximum sequence length of the model.
21
  vocab_size (int): The size of the vocabulary.
22
  resid_pdrop (float): The dropout probability applied to the attention output before combining with residual.
@@ -27,6 +68,7 @@ class MPTConfig(PretrainedConfig):
27
  attn_pdrop (float): The dropout probability for the attention layers.
28
  attn_impl (str): The attention implementation to use. One of 'torch', 'flash', or 'triton'.
29
  qk_ln (bool): Whether to apply layer normalization to the queries and keys in the attention layer.
 
30
  clip_qkv (Optional[float]): If not None, clip the queries, keys, and values in the attention layer to
31
  this value.
32
  softmax_scale (Optional[float]): If not None, scale the softmax in the attention layer by this value. If None,
@@ -38,15 +80,25 @@ class MPTConfig(PretrainedConfig):
38
  When the model is in `train` mode, this requires passing an extra `sequence_id` argument which indicates
39
  which sub-sequence each token belongs to.
40
  Defaults to ``False`` meaning any provided `sequence_id` will be ignored.
 
41
  alibi (bool): Whether to use the alibi bias instead of position embeddings.
42
  alibi_bias_max (int): The maximum value of the alibi bias.
43
  kv_n_heads (Optional[int]): For grouped_query_attention only, allow user to specify number of kv heads.
44
  ffn_config (Dict): A dictionary used to configure the model's ffn module:
45
- ffn_type (str): type of ffn to use. Options: mptmlp, te_ln_mlp
46
  init_device (str): The device to use for parameter initialization.
47
  logit_scale (Optional[Union[float, str]]): If not None, scale the logits by this value.
48
  no_bias (bool): Whether to use bias in all layers.
49
- verbose (int): The verbosity level. 0 is silent.
50
  embedding_fraction (float): The fraction to scale the gradients of the embedding layer by.
51
  norm_type (str): choose type of norm to use
52
  use_cache (bool): Whether or not the model should return the last key/values attentions
@@ -66,6 +118,8 @@ class MPTConfig(PretrainedConfig):
66
  ---
67
  See llmfoundry.models.utils.param_init_fns.py for info on other param init config options
68
  fc_type (str): choose fc layer implementation. Options: torch and te. te layers support fp8 when using H100 GPUs.
 
 
69
  """
70
  self.d_model = d_model
71
  self.n_heads = n_heads
@@ -86,55 +140,183 @@ class MPTConfig(PretrainedConfig):
86
  self.use_cache = use_cache
87
  self.init_config = init_config
88
  self.fc_type = fc_type
89
- if verbose is not None:
90
- warnings.warn(DeprecationWarning('verbose argument for MPTConfig is now ignored and will be removed. Use python_log_level instead.'))
91
- if 'name' in kwargs:
92
- del kwargs['name']
93
- if 'loss_fn' in kwargs:
94
- del kwargs['loss_fn']
95
- if self.attn_config.get('alibi', False):
96
  self.learned_pos_emb = False
97
- warnings.warn(f'alibi is turned on, setting `learned_pos_emb` to `False.`')
98
- super().__init__(**kwargs)
 
 
99
  self._validate_config()
100
 
101
- def _set_config_defaults(self, config: Dict[str, Any], config_defaults: Dict[str, Any]) -> Dict[str, Any]:
102
- for (k, v) in config_defaults.items():
 
 
103
  if k not in config:
104
  config[k] = v
105
  return config
106
 
107
  def _validate_config(self) -> None:
108
- self.attn_config = self._set_config_defaults(self.attn_config, attn_config_defaults)
109
- self.ffn_config = self._set_config_defaults(self.ffn_config, ffn_config_defaults)
110
- self.init_config = self._set_config_defaults(self.init_config, init_config_defaults)
111
  if self.d_model % self.n_heads != 0:
112
- raise ValueError('d_model must be divisible by n_heads')
113
- if any((prob < 0 or prob > 1 for prob in [self.attn_config['attn_pdrop'], self.resid_pdrop, self.emb_pdrop])):
114
- raise ValueError("self.attn_config['attn_pdrop'], resid_pdrop, emb_pdrop are probabilities and must be between 0 and 1")
115
- if self.attn_config['attn_impl'] not in ['torch', 'flash', 'triton']:
116
  raise ValueError(f"Unknown attn_impl={self.attn_config['attn_impl']}")
117
- if self.attn_config['prefix_lm'] and self.attn_config['attn_impl'] not in ['torch', 'triton']:
118
- raise NotImplementedError('prefix_lm only implemented with torch and triton attention.')
119
- if self.attn_config['alibi'] and self.attn_config['attn_impl'] not in ['torch', 'triton']:
120
- raise NotImplementedError('alibi only implemented with torch and triton attention.')
121
- if self.attn_config['attn_uses_sequence_id'] and self.attn_config['attn_impl'] not in ['torch', 'triton']:
122
- raise NotImplementedError('attn_uses_sequence_id only implemented with torch and triton attention.')
123
  if self.embedding_fraction > 1 or self.embedding_fraction <= 0:
124
- raise ValueError('model.embedding_fraction must be between 0 (exclusive) and 1 (inclusive)!')
125
- if isinstance(self.logit_scale, str) and self.logit_scale != 'inv_sqrt_d_model':
126
- raise ValueError(f"self.logit_scale={self.logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'.")
127
- if self.init_config.get('name', None) is None:
128
- raise ValueError(f"self.init_config={self.init_config!r} 'name' needs to be set.")
129
- if not self.learned_pos_emb and (not self.attn_config['alibi']):
130
- warnings.warn(f'Positional information not being provided to the model using either learned_pos_emb or alibi.')
131
- if self.fc_type == 'te' or self.ffn_config['ffn_type'] == 'te_ln_mlp':
132
  try:
133
  import transformer_engine.pytorch as te
 
134
  del te
135
  except:
136
- raise ImportError('TransformerEngine import fail. `fc_type: te` requires TransformerEngine be installed. ' + 'The required version of transformer_engine also requires FlashAttention v1.0.6 is installed:\n' + 'pip install flash-attn==1.0.6 --no-build-isolation \n' + 'pip install git+https://github.com/NVIDIA/TransformerEngine.git@144e4888b2cdd60bd52e706d5b7a79cb9c1a7156')
137
- if self.ffn_config['ffn_type'] == 'mptmlp':
138
- self.ffn_config['fc_type'] = self.fc_type
139
- elif self.ffn_config['ffn_type'] == 'te_ln_mlp':
140
- self.ffn_config['bias'] = not self.no_bias
1
  """A HuggingFace-style model configuration."""
2
+
3
  import warnings
4
  from typing import Any, Dict, Optional, Union
5
  from transformers import PretrainedConfig
6
+ from .attention import check_alibi_support, is_flash_v1_installed, is_flash_v2_installed
7
+ from .blocks import attn_config_defaults
8
+ from .fc import FC_CLASS_REGISTRY
9
+ from .norm import LPLayerNorm
10
+ from .ffn import FFN_CLASS_REGISTRY
11
+ from .warnings import VersionedDeprecationWarning
12
+
13
+ ffn_config_defaults: Dict = {"ffn_type": "mptmlp"}
14
+ init_config_defaults: Dict = {
15
+ "name": "kaiming_normal_",
16
+ "fan_mode": "fan_in",
17
+ "init_nonlinearity": "relu",
18
+ "init_div_is_residual": True,
19
+ "emb_init_std": None,
20
+ "emb_init_uniform_lim": None,
21
+ "init_std": None,
22
+ "init_gain": 0.0,
23
+ }
24
+
25
 
26
  class MPTConfig(PretrainedConfig):
27
+ model_type = "mpt"
28
 
29
+ def __init__(
30
+ self,
31
+ d_model: int = 2048,
32
+ n_heads: int = 16,
33
+ n_layers: int = 24,
34
+ expansion_ratio: Union[int, float] = 4,
35
+ max_seq_len: int = 2048,
36
+ vocab_size: int = 50368,
37
+ resid_pdrop: float = 0.0,
38
+ emb_pdrop: float = 0.0,
39
+ learned_pos_emb: bool = True,
40
+ attn_config: Dict = attn_config_defaults,
41
+ ffn_config: Dict = ffn_config_defaults,
42
+ init_device: str = "cpu",
43
+ logit_scale: Optional[Union[float, str]] = None,
44
+ no_bias: bool = False,
45
+ embedding_fraction: float = 1.0,
46
+ norm_type: str = "low_precision_layernorm",
47
+ use_cache: bool = False,
48
+ init_config: Dict = init_config_defaults,
49
+ fc_type: str = "torch",
50
+ tie_word_embeddings: bool = True,
51
+ use_pad_tok_in_ffn: bool = True,
52
+ **kwargs: Any,
53
+ ):
54
  """The MPT configuration class.
55
 
56
  Args:
57
  d_model (int): The size of the embedding dimension of the model.
58
  n_heads (int): The number of attention heads.
59
  n_layers (int): The number of layers in the model.
60
+ expansion_ratio (Union[int, float]): The ratio of the up/down scale in the ffn.
61
  max_seq_len (int): The maximum sequence length of the model.
62
  vocab_size (int): The size of the vocabulary.
63
  resid_pdrop (float): The dropout probability applied to the attention output before combining with residual.
 
68
  attn_pdrop (float): The dropout probability for the attention layers.
69
  attn_impl (str): The attention implementation to use. One of 'torch', 'flash', or 'triton'.
70
  qk_ln (bool): Whether to apply layer normalization to the queries and keys in the attention layer.
71
+ qk_gn (bool): Whether to apply group normalization to the queries and keys in the attention layer.
72
  clip_qkv (Optional[float]): If not None, clip the queries, keys, and values in the attention layer to
73
  this value.
74
  softmax_scale (Optional[float]): If not None, scale the softmax in the attention layer by this value. If None,
 
80
  When the model is in `train` mode, this requires passing an extra `sequence_id` argument which indicates
81
  which sub-sequence each token belongs to.
82
  Defaults to ``False`` meaning any provided `sequence_id` will be ignored.
83
+ sliding_window_size (int): Window size for sliding window local attention. Defaults to -1, which means no sliding window. Query at position i will only attend to keys between [i + seqlen_k - seqlen_q - window_size, i + seqlen_k - seqlen_q + window_size] inclusive. Only works for flash attention v2.3.0 or higher.
84
  alibi (bool): Whether to use the alibi bias instead of position embeddings.
85
  alibi_bias_max (int): The maximum value of the alibi bias.
86
+ rope (bool): Whether to use rotary positional embeddings.
87
+ rope_theta (int): The base frequency for rope.
88
+ rope_impl (str): The implementation of rope to use. One of 'hf' (to use the implementation from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py) or 'dail' (to use the implementation from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/layers/rotary.py).
89
+ rope_dail_config (Dict): The configuration for the dail implementation of rope.
90
+ type (str): The type of rotary position embedding to use. Options: 'original' (for https://arxiv.org/pdf/2104.09864.pdf), 'xpos' (for https://arxiv.org/pdf/2212.10554.pdf).
91
+ pos_idx_in_fp32 (bool): If True, the position indices [0, ..., seqlen - 1] are in fp32, otherwise they might be in lower precision. A consequence could be, for example, that bf16 rounds position 1995 to 2000, which leads to them having the same positional embedding.
92
+ xpos_scale_base (float): The scale base for XPos (if using XPos).
93
+ rope_hf_config (Dict): A dictionary used to configure rope's scaling behavior (when scaling beyond the training length).
94
+ type (str): Can be one of 'no_scaling', 'linear', or 'dynamic'. 'no_scaling' uses the default implementation for rotary embeddings, 'linear' uses linear scaling as proposed by the Reddit user /u/kaiokendev, and 'dynamic' uses Dynamic NTK scaling as proposed by the Reddit users /u/bloc97 and /u/emozilla.
95
+ factor (float): Scaling factor to use if using 'linear' or 'dynamic' as rope_scaling.type.
96
  kv_n_heads (Optional[int]): For grouped_query_attention only, allow user to specify number of kv heads.
97
  ffn_config (Dict): A dictionary used to configure the model's ffn module:
98
+ ffn_type (str): type of ffn to use. Options: mptmlp, mptglu, te_ln_mlp
99
  init_device (str): The device to use for parameter initialization.
100
  logit_scale (Optional[Union[float, str]]): If not None, scale the logits by this value.
101
  no_bias (bool): Whether to use bias in all layers.
 
102
  embedding_fraction (float): The fraction to scale the gradients of the embedding layer by.
103
  norm_type (str): choose type of norm to use
104
  use_cache (bool): Whether or not the model should return the last key/values attentions
 
118
  ---
119
  See llmfoundry.models.utils.param_init_fns.py for info on other param init config options
120
  fc_type (str): choose fc layer implementation. Options: torch and te. te layers support fp8 when using H100 GPUs.
121
+ tie_word_embeddings (bool): Whether to tie the input embedding and output layers.
122
+ use_pad_tok_in_ffn (bool): Whether to forward the pad token in the feedforward networks.
123
  """
124
  self.d_model = d_model
125
  self.n_heads = n_heads
 
140
  self.use_cache = use_cache
141
  self.init_config = init_config
142
  self.fc_type = fc_type
143
+ self.use_pad_tok_in_ffn = use_pad_tok_in_ffn
144
+ if "name" in kwargs:
145
+ del kwargs["name"]
146
+ if "loss_fn" in kwargs:
147
+ del kwargs["loss_fn"]
148
+ if self.attn_config.get("alibi", False) or self.attn_config.get("rope", False):
 
149
  self.learned_pos_emb = False
150
+ warnings.warn(
151
+ f"alibi or rope is turned on, setting `learned_pos_emb` to `False.`"
152
+ )
153
+ super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
154
  self._validate_config()
155
 
156
+ def _set_config_defaults(
157
+ self, config: Dict[str, Any], config_defaults: Dict[str, Any]
158
+ ) -> Dict[str, Any]:
159
+ for k, v in config_defaults.items():
160
  if k not in config:
161
  config[k] = v
162
+ elif isinstance(v, dict):
163
+ config[k] = self._set_config_defaults(
164
+ config[k] if config[k] is not None else {}, v
165
+ )
166
  return config
167
 
168
  def _validate_config(self) -> None:
169
+ self.attn_config = self._set_config_defaults(
170
+ self.attn_config, attn_config_defaults
171
+ )
172
+ self.ffn_config = self._set_config_defaults(
173
+ self.ffn_config, ffn_config_defaults
174
+ )
175
+ self.init_config = self._set_config_defaults(
176
+ self.init_config, init_config_defaults
177
+ )
178
  if self.d_model % self.n_heads != 0:
179
+ raise ValueError("d_model must be divisible by n_heads")
180
+ if any(
181
+ (
182
+ prob < 0 or prob > 1
183
+ for prob in [
184
+ self.attn_config["attn_pdrop"],
185
+ self.resid_pdrop,
186
+ self.emb_pdrop,
187
+ ]
188
+ )
189
+ ):
190
+ raise ValueError(
191
+ "self.attn_config['attn_pdrop'], resid_pdrop, emb_pdrop are probabilities and must be between 0 and 1"
192
+ )
193
+ if self.attn_config["attn_impl"] not in ["torch", "flash", "triton"]:
194
  raise ValueError(f"Unknown attn_impl={self.attn_config['attn_impl']}")
195
+ if self.attn_config["prefix_lm"] and self.attn_config["attn_impl"] not in [
196
+ "torch",
197
+ "triton",
198
+ ]:
199
+ raise NotImplementedError(
200
+ "prefix_lm only implemented with torch and triton attention."
201
+ )
202
+ if self.attn_config["attn_impl"] == "flash" and is_flash_v1_installed():
203
+ warnings.warn(
204
+ VersionedDeprecationWarning(
205
+ 'Support for Flash Attention v1 is deprecated. Please upgrade to Flash Attention v2.4.2. To install Flash Attention v2.4.2, please run `pip install -e ".[gpu-flash2]"` from the root directory of the llm-foundry repository.',
206
+ remove_version="0.6.0",
207
+ )
208
+ )
209
+ if self.attn_config["attn_impl"] == "triton" and (
210
+ not self.attn_config["prefix_lm"]
211
+ ):
212
+ warnings.warn(
213
+ UserWarning(
214
+ 'If not using a Prefix Language Model, we recommend setting "attn_impl" to "flash" instead of "triton".'
215
+ )
216
+ )
217
+ if self.attn_config["alibi"] and (
218
+ not check_alibi_support(self.attn_config["attn_impl"])
219
+ ):
220
+ raise NotImplementedError(
221
+ "alibi only implemented with torch, triton, and flash (v2.4.2 or higher) attention."
222
+ )
223
+ if self.attn_config["attn_uses_sequence_id"] and (
224
+ not (
225
+ self.attn_config["attn_impl"] in ["torch", "triton"]
226
+ or (
227
+ self.attn_config["attn_impl"] == "flash"
228
+ and is_flash_v2_installed(v2_version="v2.1.2")
229
+ )
230
+ )
231
+ ):
232
+ raise NotImplementedError(
233
+ "attn_uses_sequence_id only implemented with torch, triton, and flash (v2.1.2 or higher) attention."
234
+ )
235
+ if self.attn_config["rope"] and self.attn_config["rope_impl"] not in [
236
+ "dail",
237
+ "hf",
238
+ ]:
239
+ raise ValueError(
240
+ 'If rope is being used then rope_impl should be either "dail", or "hf".'
241
+ )
242
+ if (
243
+ self.attn_config["rope"]
244
+ and self.attn_config["rope_impl"] == "hf"
245
+ and (
246
+ self.attn_config["rope_hf_config"]["type"]
247
+ not in ["no_scaling", "linear", "dynamic"]
248
+ )
249
+ ):
250
+ raise ValueError(
251
+ 'If using hf implementation of rope, the type should be one of "no_scaling", "linear" or "dynamic".'
252
+ )
253
+ if self.attn_config["rope"] and self.attn_config["rope_impl"] == "dail":
254
+ if self.attn_config["rope_dail_config"]["type"] not in ["original", "xpos"]:
255
+ raise ValueError(
256
+ 'If using the dail implementation of rope, the type should be one of "original" or "xpos".'
257
+ )
258
+ if not is_flash_v2_installed(v2_version="2.0.1"):
259
+ raise ImportError(
260
+ "If using the dail implementation of rope, the flash_attn library v2.0.1 or higher must be installed. Please check the instructions at https://github.com/mosaicml/llm-foundry/blob/main/TUTORIAL.md#what-kinds-of-positional-embeddings-does-llm-foundry-support"
261
+ )
262
+ if self.attn_config["sliding_window_size"] != -1 and (
263
+ not (
264
+ self.attn_config["attn_impl"] == "flash"
265
+ and is_flash_v2_installed(v2_version="v2.3.0")
266
+ )
267
+ ):
268
+ raise NotImplementedError(
269
+ "sliding window only implemented with flash attention v2.3.0 or higher."
270
+ )
271
  if self.embedding_fraction > 1 or self.embedding_fraction <= 0:
272
+ raise ValueError(
273
+ "model.embedding_fraction must be between 0 (exclusive) and 1 (inclusive)!"
274
+ )
275
+ if isinstance(self.logit_scale, str) and self.logit_scale != "inv_sqrt_d_model":
276
+ raise ValueError(
277
+ f"self.logit_scale={self.logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'."
278
+ )
279
+ if self.init_config.get("name", None) is None:
280
+ raise ValueError(
281
+ f"self.init_config={self.init_config!r} 'name' needs to be set."
282
+ )
283
+ if not (
284
+ self.learned_pos_emb
285
+ or self.attn_config["alibi"]
286
+ or self.attn_config["rope"]
287
+ ):
288
+ warnings.warn(
289
+ f"Positional information not being provided to the model using either learned_pos_emb or alibi or rope."
290
+ )
291
+ if self.fc_type == "te" or self.ffn_config["ffn_type"] == "te_ln_mlp":
292
  try:
293
  import transformer_engine.pytorch as te
294
+
295
  del te
296
  except:
297
+ raise ImportError(
298
+ "TransformerEngine import fail. `fc_type: te` requires TransformerEngine be installed. "
299
+ + "The required version of transformer_engine also requires FlashAttention v1.0.6 is installed:\n"
300
+ + "pip install flash-attn==1.0.6 --no-build-isolation \n"
301
+ + "pip install git+https://github.com/NVIDIA/TransformerEngine.git@144e4888b2cdd60bd52e706d5b7a79cb9c1a7156"
302
+ )
303
+ if self.ffn_config["ffn_type"] == "mptgeglu":
304
+ raise ValueError(
305
+ 'API CHANGE: `ffn_type=="mptgeglu"` changed to `ffn_type=="mptglu"`. '
306
+ + "See [#829](https://github.com/mosaicml/llm-foundry/pull/829) for details."
307
+ )
308
+ elif self.ffn_config["ffn_type"] in ["mptmlp", "mptglu"]:
309
+ self.ffn_config["fc_type"] = self.fc_type
310
+ elif self.ffn_config["ffn_type"] == "te_ln_mlp":
311
+ self.ffn_config["bias"] = not self.no_bias
312
+ if "ffn_act_fn" in self.ffn_config.keys():
313
+ raise ValueError(
314
+ f"Transformer Engine block does not support custom activation functions."
315
+ )
316
+ if not self.use_pad_tok_in_ffn:
317
+ try:
318
+ from flash_attn.bert_padding import unpad_input, pad_input
319
+ except:
320
+ raise ImportError(
321
+ "In order to set `use_pad_tok_in_ffn=False`, please install flash-attn==1.0.9 or flash-attn==2.3.6"
322
+ )
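
`_set_config_defaults` now recurses into nested dictionaries, so a partially specified `rope_dail_config` or `rope_hf_config` still picks up the remaining defaults. A standalone re-implementation of that merge, mirroring the method above so it runs without the repository on the path:

```python
from typing import Any, Dict

def set_config_defaults(config: Dict[str, Any], defaults: Dict[str, Any]) -> Dict[str, Any]:
    # Missing keys are filled in; nested dict defaults are merged recursively
    # instead of overwriting a user-supplied sub-config wholesale.
    for k, v in defaults.items():
        if k not in config:
            config[k] = v
        elif isinstance(v, dict):
            config[k] = set_config_defaults(config[k] if config[k] is not None else {}, v)
    return config

defaults = {
    "rope": False,
    "rope_dail_config": {"type": "original", "pos_idx_in_fp32": True, "xpos_scale_base": 512},
}
user_cfg = {"rope": True, "rope_dail_config": {"type": "xpos"}}
print(set_config_defaults(user_cfg, defaults))
# {'rope': True, 'rope_dail_config': {'type': 'xpos', 'pos_idx_in_fp32': True, 'xpos_scale_base': 512}}
```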
custom_embedding.py CHANGED
@@ -2,9 +2,10 @@ import torch.nn as nn
2
  import torch.nn.functional as F
3
  from torch import Tensor
4
 
 
5
  class SharedEmbedding(nn.Embedding):
6
 
7
- def forward(self, input: Tensor, unembed: bool=False) -> Tensor:
8
  if unembed:
9
  return F.linear(input, self.weight)
10
- return super().forward(input)
 
2
  import torch.nn.functional as F
3
  from torch import Tensor
4
 
5
+
6
  class SharedEmbedding(nn.Embedding):
7
 
8
+ def forward(self, input: Tensor, unembed: bool = False) -> Tensor:
9
  if unembed:
10
  return F.linear(input, self.weight)
11
+ return super().forward(input)
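
`SharedEmbedding` reuses a single weight matrix for both the embedding look-up and the output projection. The plain-torch sketch below shows what the `unembed=True` path computes:

```python
import torch
import torch.nn.functional as F
from torch import nn

emb = nn.Embedding(num_embeddings=10, embedding_dim=4)
tokens = torch.tensor([[1, 2, 3]])

hidden = emb(tokens)                   # embed:   (1, 3, 4)
logits = F.linear(hidden, emb.weight)  # unembed: (1, 3, 10), same weights, no extra head
print(hidden.shape, logits.shape)
```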
fc.py CHANGED
@@ -1,7 +1,9 @@
1
  from torch import nn
2
- FC_CLASS_REGISTRY = {'torch': nn.Linear}
 
3
  try:
4
  import transformer_engine.pytorch as te
5
- FC_CLASS_REGISTRY['te'] = te.Linear
 
6
  except:
7
- pass
 
1
  from torch import nn
2
+
3
+ FC_CLASS_REGISTRY = {"torch": nn.Linear}
4
  try:
5
  import transformer_engine.pytorch as te
6
+
7
+ FC_CLASS_REGISTRY["te"] = te.Linear
8
  except:
9
+ pass
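
Elsewhere in the repository the registry is consumed by looking the linear-layer class up by `fc_type`, with `"te"` only available when TransformerEngine imports. The sketch below mirrors that consumer pattern (illustrative only; the original file uses a bare `except`):

```python
from torch import nn

FC_CLASS_REGISTRY = {"torch": nn.Linear}
try:
    import transformer_engine.pytorch as te  # optional dependency
    FC_CLASS_REGISTRY["te"] = te.Linear
except ImportError:
    pass

fc_type = "te" if "te" in FC_CLASS_REGISTRY else "torch"  # prefer te when available
proj = FC_CLASS_REGISTRY[fc_type](4096, 4096, bias=True)
print(fc_type, type(proj).__name__)
```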
ffn.py CHANGED
@@ -1,39 +1,173 @@
1
- """GPT Blocks used for the GPT Model."""
2
- from typing import Any, Optional
3
  import torch
4
  import torch.nn as nn
5
  from .fc import FC_CLASS_REGISTRY
 
6
  try:
7
  import transformer_engine.pytorch as te
8
  except:
9
  te = None
10
 
11
  class MPTMLP(nn.Module):
12
 
13
- def __init__(self, d_model: int, expansion_ratio: int, fc_type: str='torch', device: Optional[str]=None, bias: bool=True):
14
  super().__init__()
15
- fc_kwargs: dict[str, Any] = {'bias': bias}
16
- if fc_type != 'te':
17
- fc_kwargs['device'] = device
18
- self.up_proj = FC_CLASS_REGISTRY[fc_type](d_model, expansion_ratio * d_model, **fc_kwargs)
19
- self.act = nn.GELU(approximate='none')
20
- self.down_proj = FC_CLASS_REGISTRY[fc_type](expansion_ratio * d_model, d_model, **fc_kwargs)
21
  self.down_proj._is_residual = True
22
 
23
  def forward(self, x: torch.Tensor) -> torch.Tensor:
24
  return self.down_proj(self.act(self.up_proj(x)))
25
- FFN_CLASS_REGISTRY = {'mptmlp': MPTMLP}
26
  if te is not None:
27
  te.LayerNormMLP._has_norm = True
28
- FFN_CLASS_REGISTRY['te_ln_mlp'] = te.LayerNormMLP
 
29
 
30
- def build_ffn(d_model: int, expansion_ratio: int, fc_type: str='torch', device: Optional[str]=None, bias: bool=True, **kwargs: Any) -> nn.Module:
31
- ffn_type = kwargs.pop('ffn_type')
32
- if ffn_type == 'mptmlp':
33
  if len(kwargs) > 0:
34
- raise ValueError(f'MPTMLP got an unexpected keyword argument: {kwargs}')
35
- return MPTMLP(d_model=d_model, expansion_ratio=expansion_ratio, fc_type=fc_type, device=device, bias=bias)
36
- elif ffn_type == 'te_ln_mlp':
37
  assert te is not None
38
- return te.LayerNormMLP(hidden_size=d_model, ffn_hidden_size=d_model * expansion_ratio, bias=bias, **kwargs)
39
- raise ValueError(f'ffn_type={ffn_type!r} not recognized.')
1
+ """MPT Blocks used for the MPT Model."""
2
+
3
+ import logging
4
+ from copy import deepcopy
5
+ from functools import partial
6
+ from typing import Any, Callable, Optional, Union
7
  import torch
8
  import torch.nn as nn
9
  from .fc import FC_CLASS_REGISTRY
10
+
11
  try:
12
  import transformer_engine.pytorch as te
13
  except:
14
  te = None
15
+ log = logging.getLogger(__name__)
16
+ _FFN_ACT_FN_DEFAULT = {"name": "gelu", "approximate": "none"}
17
+
18
+
19
+ def resolve_ffn_act_fn(
20
+ config: Optional[dict] = None,
21
+ ) -> Callable[[torch.Tensor], torch.Tensor]:
22
+ """Resolve the activation function for the feed-forward network.
23
+ Args:
24
+ config (Optional[dict]): The configuration dictionary for the activation function.
25
+ The dict config must specify the 'name' of a torch.nn.functional activation
26
+ function. All other key-value pairs are bound to the function as a partial.
27
+ Returns:
28
+ Callable[[torch.Tensor], torch.Tensor]: The activation function.
29
+ """
30
+ if config is None:
31
+ config = _FFN_ACT_FN_DEFAULT
32
+ config = deepcopy(config)
33
+ name = config.pop("name")
34
+ if not hasattr(torch.nn.functional, name):
35
+ raise ValueError(f"Unrecognised activation function name ({name}).")
36
+ act = getattr(torch.nn.functional, name)
37
+ return partial(act, **config)
38
+
39
+
40
+ _DEFAULT_ACT_FN = resolve_ffn_act_fn(_FFN_ACT_FN_DEFAULT)
41
+
42
+
43
+ def resolve_ffn_hidden_size(
44
+ d_model: int,
45
+ expansion_ratio: Union[int, float],
46
+ ffn_hidden_size: Optional[int] = None,
47
+ ) -> int:
48
+ """Resolve the hidden size of the feed-forward network.
49
+ Args:
50
+ d_model (int): The dimension of the input and output of the feed-forward network.
51
+ expansion_ratio (Union[int, float]): The expansion ratio of the feed-forward network.
52
+ ffn_hidden_size (Optional[int]): The hidden size of the feed-forward network.
53
+ Returns:
54
+ int: The hidden size of the feed-forward network.
55
+ """
56
+ if ffn_hidden_size is not None:
57
+ log.info(
58
+ f"`expansion_ratio` (={expansion_ratio}) ignored when `ffn_hidden_size` (={ffn_hidden_size}) is specified."
59
+ )
60
+ else:
61
+ ffn_hidden_size = int(d_model * expansion_ratio)
62
+ if ffn_hidden_size != d_model * expansion_ratio:
63
+ raise ValueError(
64
+ f"`d_model * expansion_ratio` must be an integer (d_model={d_model!r}; expansion_ratio={expansion_ratio!r}; d_model * expansion_ratio={d_model * expansion_ratio!r})."
65
+ )
66
+ return ffn_hidden_size
67
+
68
 
69
  class MPTMLP(nn.Module):
70
 
71
+ def __init__(
72
+ self,
73
+ d_model: int,
74
+ expansion_ratio: Union[int, float],
75
+ fc_type: str = "torch",
76
+ ffn_hidden_size: Optional[int] = None,
77
+ act_fn: Callable[[torch.Tensor], torch.Tensor] = _DEFAULT_ACT_FN,
78
+ device: Optional[str] = None,
79
+ bias: bool = True,
80
+ ):
81
  super().__init__()
82
+ ffn_hidden_size = resolve_ffn_hidden_size(
83
+ d_model, expansion_ratio, ffn_hidden_size
84
+ )
85
+ self.fc_kwargs: dict[str, Any] = {"bias": bias}
86
+ if fc_type != "te":
87
+ self.fc_kwargs["device"] = device
88
+ self.up_proj = FC_CLASS_REGISTRY[fc_type](
89
+ d_model, ffn_hidden_size, **self.fc_kwargs
90
+ )
91
+ self.act = act_fn
92
+ self.down_proj = FC_CLASS_REGISTRY[fc_type](
93
+ ffn_hidden_size, d_model, **self.fc_kwargs
94
+ )
95
  self.down_proj._is_residual = True
96
 
97
  def forward(self, x: torch.Tensor) -> torch.Tensor:
98
  return self.down_proj(self.act(self.up_proj(x)))
99
+
100
+
101
+ class MPTGLU(MPTMLP):
102
+
103
+ def __init__(
104
+ self,
105
+ d_model: int,
106
+ expansion_ratio: Union[int, float],
107
+ fc_type: str = "torch",
108
+ ffn_hidden_size: Optional[int] = None,
109
+ act_fn: Callable[[torch.Tensor], torch.Tensor] = _DEFAULT_ACT_FN,
110
+ device: Optional[str] = None,
111
+ bias: bool = True,
112
+ ):
113
+ super().__init__(
114
+ d_model=d_model,
115
+ expansion_ratio=expansion_ratio,
116
+ fc_type=fc_type,
117
+ ffn_hidden_size=ffn_hidden_size,
118
+ act_fn=act_fn,
119
+ device=device,
120
+ bias=bias,
121
+ )
122
+ self.gate_proj = FC_CLASS_REGISTRY[fc_type](
123
+ d_model, self.up_proj.out_features, **self.fc_kwargs
124
+ )
125
+
126
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
127
+ return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x))
128
+
129
+
130
+ FFN_CLASS_REGISTRY = {"mptmlp": MPTMLP, "mptglu": MPTGLU}
131
  if te is not None:
132
  te.LayerNormMLP._has_norm = True
133
+ FFN_CLASS_REGISTRY["te_ln_mlp"] = te.LayerNormMLP
134
+
135
 
136
+ def build_ffn(
137
+ d_model: int,
138
+ expansion_ratio: Union[int, float],
139
+ fc_type: str = "torch",
140
+ ffn_hidden_size: Optional[int] = None,
141
+ ffn_act_fn: Optional[dict] = None,
142
+ device: Optional[str] = None,
143
+ bias: bool = True,
144
+ **kwargs: Any,
145
+ ) -> nn.Module:
146
+ ffn_type = kwargs.pop("ffn_type")
147
+ if ffn_type in ["mptmlp", "mptglu"]:
148
  if len(kwargs) > 0:
149
+ raise ValueError(
150
+ f"MPTMLP (or MPTGLU) got an unexpected keyword argument: {kwargs}"
151
+ )
152
+ return FFN_CLASS_REGISTRY[ffn_type](
153
+ d_model=d_model,
154
+ expansion_ratio=expansion_ratio,
155
+ fc_type=fc_type,
156
+ act_fn=resolve_ffn_act_fn(ffn_act_fn),
157
+ ffn_hidden_size=ffn_hidden_size,
158
+ device=device,
159
+ bias=bias,
160
+ )
161
+ elif ffn_type == "te_ln_mlp":
162
  assert te is not None
163
+ ffn_hidden_size = resolve_ffn_hidden_size(
164
+ d_model, expansion_ratio, ffn_hidden_size
165
+ )
166
+ if ffn_act_fn is not None:
167
+ raise ValueError(
168
+ f"Transformer Engine block does not support custom activation functions."
169
+ )
170
+ return te.LayerNormMLP(
171
+ hidden_size=d_model, ffn_hidden_size=ffn_hidden_size, bias=bias, **kwargs
172
+ )
173
+ raise ValueError(f"ffn_type={ffn_type!r} not recognized.")
flash_attn_triton.py CHANGED
@@ -1,17 +1,14 @@
1
  """
2
  Copied from https://github.com/HazyResearch/flash-attention/blob/eff9fe6b8076df59d64d7a3f464696738a3c7c24/flash_attn/flash_attn_triton.py
3
  update imports to use 'triton_pre_mlir'
4
-
5
  *Experimental* implementation of FlashAttention in Triton.
6
  Tested with triton==2.0.0.dev20221202.
7
  Triton 2.0 has a new backend (MLIR) but seems like it doesn't yet work for head dimensions
8
  other than 64:
9
  https://github.com/openai/triton/blob/d376020f90002757eea3ea9475d4f7cfc2ec5ead/python/triton/ops/flash_attention.py#L207
10
  We'll update this implementation with the new Triton backend once this is fixed.
11
-
12
  We use the FlashAttention implementation from Phil Tillet a starting point.
13
  https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py
14
-
15
  Changes:
16
  - Implement both causal and non-causal attention.
17
  - Implement both self-attention and cross-attention.
@@ -22,7 +19,6 @@ Changes:
22
  - Make the backward for d=128 much faster by reducing register spilling.
23
  - Optionally parallelize the backward pass across seqlen_k, to deal with the case of
24
  small batch size * nheads.
25
-
26
  Caution:
27
  - This is an *experimental* implementation. The forward pass should be quite robust but
28
  I'm not 100% sure that the backward pass doesn't have race conditions (due to the Triton compiler).
@@ -32,7 +28,6 @@ I'm not 100% sure that the backward pass doesn't have race conditions (due to th
32
  "test_flash_attn_triton_race_condition". I've tested and fixed many race conditions
33
  for different head dimensions (40, 48, 64, 128, 80, 88, 96), but I'm still not 100% confident
34
  that there are none left for other head dimensions.
35
-
36
  Differences between this Triton version and the CUDA version:
37
  - Triton version doesn't support dropout.
38
  - Triton forward is generally faster than CUDA forward, while Triton backward is
@@ -41,14 +36,61 @@ than CUDA forward + backward.
41
  - Triton version doesn't support different sequence lengths in a batch (i.e., RaggedTensor/NestedTensor).
42
  - Triton version supports attention bias, while CUDA version doesn't.
43
  """
 
44
  import math
45
  import torch
46
  import triton_pre_mlir as triton
47
  import triton_pre_mlir.language as tl
48
 
49
- @triton.heuristics({'EVEN_M': lambda args: args['seqlen_q'] % args['BLOCK_M'] == 0, 'EVEN_N': lambda args: args['seqlen_k'] % args['BLOCK_N'] == 0, 'EVEN_HEADDIM': lambda args: args['headdim'] == args['BLOCK_HEADDIM']})
50
  @triton.jit
51
- def _fwd_kernel(Q, K, V, Bias, Out, Lse, TMP, softmax_scale, stride_qb, stride_qh, stride_qm, stride_kb, stride_kh, stride_kn, stride_vb, stride_vh, stride_vn, stride_bb, stride_bh, stride_bm, stride_ob, stride_oh, stride_om, nheads, seqlen_q, seqlen_k, seqlen_q_rounded, headdim, CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K, BIAS_TYPE: tl.constexpr, IS_CAUSAL: tl.constexpr, BLOCK_HEADDIM: tl.constexpr, EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
52
  start_m = tl.program_id(0)
53
  off_hb = tl.program_id(1)
54
  off_b = off_hb // nheads
@@ -56,16 +98,36 @@ def _fwd_kernel(Q, K, V, Bias, Out, Lse, TMP, softmax_scale, stride_qb, stride_q
56
  offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
57
  offs_n = tl.arange(0, BLOCK_N)
58
  offs_d = tl.arange(0, BLOCK_HEADDIM)
59
- q_ptrs = Q + off_b * stride_qb + off_h * stride_qh + (offs_m[:, None] * stride_qm + offs_d[None, :])
60
- k_ptrs = K + off_b * stride_kb + off_h * stride_kh + (offs_n[:, None] * stride_kn + offs_d[None, :])
61
- v_ptrs = V + off_b * stride_vb + off_h * stride_vh + (offs_n[:, None] * stride_vn + offs_d[None, :])
62
- if BIAS_TYPE == 'vector':
63
  b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + offs_n
64
- elif BIAS_TYPE == 'matrix':
65
- b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + (offs_m[:, None] * stride_bm + offs_n[None, :])
66
  t_ptrs = TMP + off_hb * seqlen_q_rounded + offs_m
67
- lse_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float('inf')
68
- m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float('inf')
69
  acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)
70
  if EVEN_M & EVEN_N:
71
  if EVEN_HEADDIM:
@@ -75,7 +137,11 @@ def _fwd_kernel(Q, K, V, Bias, Out, Lse, TMP, softmax_scale, stride_qb, stride_q
75
  elif EVEN_HEADDIM:
76
  q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0)
77
  else:
78
- q = tl.load(q_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0)
79
  end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k)
80
  for start_n in range(0, end_n, BLOCK_N):
81
  start_n = tl.multiple_of(start_n, BLOCK_N)
@@ -83,29 +149,51 @@ def _fwd_kernel(Q, K, V, Bias, Out, Lse, TMP, softmax_scale, stride_qb, stride_q
83
  if EVEN_HEADDIM:
84
  k = tl.load(k_ptrs + start_n * stride_kn)
85
  else:
86
- k = tl.load(k_ptrs + start_n * stride_kn, mask=offs_d[None, :] < headdim, other=0.0)
87
  elif EVEN_HEADDIM:
88
- k = tl.load(k_ptrs + start_n * stride_kn, mask=(start_n + offs_n)[:, None] < seqlen_k, other=0.0)
89
  else:
90
- k = tl.load(k_ptrs + start_n * stride_kn, mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0)
91
  qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
92
  qk += tl.dot(q, k, trans_b=True)
93
  if not EVEN_N:
94
- qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float('-inf'))
95
  if IS_CAUSAL:
96
- qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float('-inf'))
97
- if BIAS_TYPE != 'none':
98
- if BIAS_TYPE == 'vector':
 
 
99
  if EVEN_N:
100
  bias = tl.load(b_ptrs + start_n).to(tl.float32)
101
  else:
102
- bias = tl.load(b_ptrs + start_n, mask=start_n + offs_n < seqlen_k, other=0.0).to(tl.float32)
 
 
103
  bias = bias[None, :]
104
- elif BIAS_TYPE == 'matrix':
105
  if EVEN_M & EVEN_N:
106
  bias = tl.load(b_ptrs + start_n).to(tl.float32)
107
  else:
108
- bias = tl.load(b_ptrs + start_n, mask=(offs_m[:, None] < seqlen_q) & ((start_n + offs_n)[None, :] < seqlen_k), other=0.0).to(tl.float32)
109
  qk = qk * softmax_scale + bias
110
  m_ij = tl.maximum(tl.max(qk, 1), lse_i)
111
  p = tl.exp(qk - m_ij[:, None])
@@ -121,11 +209,24 @@ def _fwd_kernel(Q, K, V, Bias, Out, Lse, TMP, softmax_scale, stride_qb, stride_q
121
  if EVEN_HEADDIM:
122
  v = tl.load(v_ptrs + start_n * stride_vn)
123
  else:
124
- v = tl.load(v_ptrs + start_n * stride_vn, mask=offs_d[None, :] < headdim, other=0.0)
125
  elif EVEN_HEADDIM:
126
- v = tl.load(v_ptrs + start_n * stride_vn, mask=(start_n + offs_n)[:, None] < seqlen_k, other=0.0)
127
  else:
128
- v = tl.load(v_ptrs + start_n * stride_vn, mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0)
129
  p = p.to(v.dtype)
130
  acc_o += tl.dot(p, v)
131
  m_i = m_ij
@@ -140,7 +241,12 @@ def _fwd_kernel(Q, K, V, Bias, Out, Lse, TMP, softmax_scale, stride_qb, stride_q
140
  lse_ptrs = Lse + off_hb * seqlen_q_rounded + offs_m
141
  tl.store(lse_ptrs, lse_i)
142
  offs_d = tl.arange(0, BLOCK_HEADDIM)
143
- out_ptrs = Out + off_b * stride_ob + off_h * stride_oh + (offs_m[:, None] * stride_om + offs_d[None, :])
144
  if EVEN_M:
145
  if EVEN_HEADDIM:
146
  tl.store(out_ptrs, acc_o)
@@ -149,23 +255,73 @@ def _fwd_kernel(Q, K, V, Bias, Out, Lse, TMP, softmax_scale, stride_qb, stride_q
149
  elif EVEN_HEADDIM:
150
  tl.store(out_ptrs, acc_o, mask=offs_m[:, None] < seqlen_q)
151
  else:
152
- tl.store(out_ptrs, acc_o, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim))
153
 
154
  @triton.jit
155
- def _bwd_preprocess_do_o_dot(Out, DO, Delta, stride_ob, stride_oh, stride_om, stride_dob, stride_doh, stride_dom, nheads, seqlen_q, seqlen_q_rounded, headdim, BLOCK_M: tl.constexpr, BLOCK_HEADDIM: tl.constexpr):
156
  start_m = tl.program_id(0)
157
  off_hb = tl.program_id(1)
158
  off_b = off_hb // nheads
159
  off_h = off_hb % nheads
160
  offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
161
  offs_d = tl.arange(0, BLOCK_HEADDIM)
162
- o = tl.load(Out + off_b * stride_ob + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :], mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0).to(tl.float32)
163
- do = tl.load(DO + off_b * stride_dob + off_h * stride_doh + offs_m[:, None] * stride_dom + offs_d[None, :], mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0).to(tl.float32)
164
  delta = tl.sum(o * do, axis=1)
165
  tl.store(Delta + off_hb * seqlen_q_rounded + offs_m, delta)
166
 
 
167
  @triton.jit
168
- def _bwd_store_dk_dv(dk_ptrs, dv_ptrs, dk, dv, offs_n, offs_d, seqlen_k, headdim, EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr):
169
  if EVEN_N & EVEN_M:
170
  if EVEN_HEADDIM:
171
  tl.store(dv_ptrs, dv)
@@ -177,11 +333,49 @@ def _bwd_store_dk_dv(dk_ptrs, dv_ptrs, dk, dv, offs_n, offs_d, seqlen_k, headdim
177
  tl.store(dv_ptrs, dv, mask=offs_n[:, None] < seqlen_k)
178
  tl.store(dk_ptrs, dk, mask=offs_n[:, None] < seqlen_k)
179
  else:
180
- tl.store(dv_ptrs, dv, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim))
181
- tl.store(dk_ptrs, dk, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim))
182
 
183
  @triton.jit
184
- def _bwd_kernel_one_col_block(start_n, Q, K, V, Bias, DO, DQ, DK, DV, LSE, D, softmax_scale, stride_qm, stride_kn, stride_vn, stride_bm, stride_dom, stride_dqm, stride_dkn, stride_dvn, seqlen_q, seqlen_k, headdim, ATOMIC_ADD: tl.constexpr, BIAS_TYPE: tl.constexpr, IS_CAUSAL: tl.constexpr, BLOCK_HEADDIM: tl.constexpr, EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
185
  begin_m = 0 if not IS_CAUSAL else start_n * BLOCK_N // BLOCK_M * BLOCK_M
186
  offs_qm = begin_m + tl.arange(0, BLOCK_M)
187
  offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
@@ -192,16 +386,28 @@ def _bwd_kernel_one_col_block(start_n, Q, K, V, Bias, DO, DQ, DK, DV, LSE, D, so
192
  v_ptrs = V + (offs_n[:, None] * stride_vn + offs_d[None, :])
193
  do_ptrs = DO + (offs_qm[:, None] * stride_dom + offs_d[None, :])
194
  dq_ptrs = DQ + (offs_qm[:, None] * stride_dqm + offs_d[None, :])
195
- if BIAS_TYPE == 'vector':
196
  b_ptrs = Bias + offs_n
197
- elif BIAS_TYPE == 'matrix':
198
  b_ptrs = Bias + (offs_qm[:, None] * stride_bm + offs_n[None, :])
199
  dv = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
200
  dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
201
  if begin_m >= seqlen_q:
202
  dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :])
203
  dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :])
204
- _bwd_store_dk_dv(dk_ptrs, dv_ptrs, dk, dv, offs_n, offs_d, seqlen_k, headdim, EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM)
205
  return
206
  if EVEN_N & EVEN_M:
207
  if EVEN_HEADDIM:
@@ -214,8 +420,16 @@ def _bwd_kernel_one_col_block(start_n, Q, K, V, Bias, DO, DQ, DK, DV, LSE, D, so
214
  k = tl.load(k_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0)
215
  v = tl.load(v_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0)
216
  else:
217
- k = tl.load(k_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0)
218
- v = tl.load(v_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0)
219
  num_block_m = tl.cdiv(seqlen_q, BLOCK_M)
220
  for start_m in range(begin_m, num_block_m * BLOCK_M, BLOCK_M):
221
  start_m = tl.multiple_of(start_m, BLOCK_M)
@@ -225,37 +439,52 @@ def _bwd_kernel_one_col_block(start_n, Q, K, V, Bias, DO, DQ, DK, DV, LSE, D, so
225
  elif EVEN_HEADDIM:
226
  q = tl.load(q_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0)
227
  else:
228
- q = tl.load(q_ptrs, mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0)
229
  qk = tl.dot(q, k, trans_b=True)
230
  if not EVEN_N:
231
- qk = tl.where(offs_n[None, :] < seqlen_k, qk, float('-inf'))
232
  if IS_CAUSAL:
233
- qk = tl.where(offs_m_curr[:, None] >= offs_n[None, :], qk, float('-inf'))
234
- if BIAS_TYPE != 'none':
235
  tl.debug_barrier()
236
- if BIAS_TYPE == 'vector':
237
  if EVEN_N:
238
  bias = tl.load(b_ptrs).to(tl.float32)
239
  else:
240
- bias = tl.load(b_ptrs, mask=offs_n < seqlen_k, other=0.0).to(tl.float32)
 
 
241
  bias = bias[None, :]
242
- elif BIAS_TYPE == 'matrix':
243
  if EVEN_M & EVEN_N:
244
  bias = tl.load(b_ptrs).to(tl.float32)
245
  else:
246
- bias = tl.load(b_ptrs, mask=(offs_m_curr[:, None] < seqlen_q) & (offs_n[None, :] < seqlen_k), other=0.0).to(tl.float32)
247
  qk = qk * softmax_scale + bias
248
  if not EVEN_M & EVEN_HEADDIM:
249
  tl.debug_barrier()
250
  lse_i = tl.load(LSE + offs_m_curr)
251
- if BIAS_TYPE == 'none':
252
  p = tl.exp(qk * softmax_scale - lse_i[:, None])
253
  else:
254
  p = tl.exp(qk - lse_i[:, None])
255
  if EVEN_M & EVEN_HEADDIM:
256
  do = tl.load(do_ptrs)
257
  else:
258
- do = tl.load(do_ptrs, mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0)
259
  dv += tl.dot(p.to(do.dtype), do, trans_a=True)
260
  if not EVEN_M & EVEN_HEADDIM:
261
  tl.debug_barrier()
@@ -269,17 +498,39 @@ def _bwd_kernel_one_col_block(start_n, Q, K, V, Bias, DO, DQ, DK, DV, LSE, D, so
269
  tl.debug_barrier()
270
  if not ATOMIC_ADD:
271
  if EVEN_M & EVEN_HEADDIM:
272
- dq = tl.load(dq_ptrs, eviction_policy='evict_last')
273
  dq += tl.dot(ds, k)
274
- tl.store(dq_ptrs, dq, eviction_policy='evict_last')
275
  elif EVEN_HEADDIM:
276
- dq = tl.load(dq_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0, eviction_policy='evict_last')
277
  dq += tl.dot(ds, k)
278
- tl.store(dq_ptrs, dq, mask=offs_m_curr[:, None] < seqlen_q, eviction_policy='evict_last')
279
  else:
280
- dq = tl.load(dq_ptrs, mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0, eviction_policy='evict_last')
281
  dq += tl.dot(ds, k)
282
- tl.store(dq_ptrs, dq, mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), eviction_policy='evict_last')
283
  else:
284
  dq = tl.dot(ds, k)
285
  if EVEN_M & EVEN_HEADDIM:
@@ -287,23 +538,122 @@ def _bwd_kernel_one_col_block(start_n, Q, K, V, Bias, DO, DQ, DK, DV, LSE, D, so
287
  elif EVEN_HEADDIM:
288
  tl.atomic_add(dq_ptrs, dq, mask=offs_m_curr[:, None] < seqlen_q)
289
  else:
290
- tl.atomic_add(dq_ptrs, dq, mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim))
291
  dq_ptrs += BLOCK_M * stride_dqm
292
  q_ptrs += BLOCK_M * stride_qm
293
  do_ptrs += BLOCK_M * stride_dom
294
- if BIAS_TYPE == 'matrix':
295
  b_ptrs += BLOCK_M * stride_bm
296
  dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :])
297
  dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :])
298
- _bwd_store_dk_dv(dk_ptrs, dv_ptrs, dk, dv, offs_n, offs_d, seqlen_k, headdim, EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM)
299
 
300
  def init_to_zero(name):
301
  return lambda nargs: nargs[name].zero_()
302
 
303
- @triton.autotune(configs=[triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'SEQUENCE_PARALLEL': False}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')), triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'SEQUENCE_PARALLEL': True}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ'))], key=['CACHE_KEY_SEQLEN_Q', 'CACHE_KEY_SEQLEN_K', 'BIAS_TYPE', 'IS_CAUSAL', 'BLOCK_HEADDIM'])
304
- @triton.heuristics({'EVEN_M': lambda args: args['seqlen_q'] % args['BLOCK_M'] == 0, 'EVEN_N': lambda args: args['seqlen_k'] % args['BLOCK_N'] == 0, 'EVEN_HEADDIM': lambda args: args['headdim'] == args['BLOCK_HEADDIM']})
305
  @triton.jit
306
- def _bwd_kernel(Q, K, V, Bias, DO, DQ, DK, DV, LSE, D, softmax_scale, stride_qb, stride_qh, stride_qm, stride_kb, stride_kh, stride_kn, stride_vb, stride_vh, stride_vn, stride_bb, stride_bh, stride_bm, stride_dob, stride_doh, stride_dom, stride_dqb, stride_dqh, stride_dqm, stride_dkb, stride_dkh, stride_dkn, stride_dvb, stride_dvh, stride_dvn, nheads, seqlen_q, seqlen_k, seqlen_q_rounded, headdim, CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K, BIAS_TYPE: tl.constexpr, IS_CAUSAL: tl.constexpr, BLOCK_HEADDIM: tl.constexpr, SEQUENCE_PARALLEL: tl.constexpr, EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
307
  off_hb = tl.program_id(1)
308
  off_b = off_hb // nheads
309
  off_h = off_hb % nheads
@@ -314,30 +664,97 @@ def _bwd_kernel(Q, K, V, Bias, DO, DQ, DK, DV, LSE, D, softmax_scale, stride_qb,
314
  DQ += off_b * stride_dqb + off_h * stride_dqh
315
  DK += off_b * stride_dkb + off_h * stride_dkh
316
  DV += off_b * stride_dvb + off_h * stride_dvh
317
- if BIAS_TYPE != 'none':
318
  Bias += off_b * stride_bb + off_h * stride_bh
319
  D += off_hb * seqlen_q_rounded
320
  LSE += off_hb * seqlen_q_rounded
321
  if not SEQUENCE_PARALLEL:
322
  num_block_n = tl.cdiv(seqlen_k, BLOCK_N)
323
  for start_n in range(0, num_block_n):
324
- _bwd_kernel_one_col_block(start_n, Q, K, V, Bias, DO, DQ, DK, DV, LSE, D, softmax_scale, stride_qm, stride_kn, stride_vn, stride_bm, stride_dom, stride_dqm, stride_dkn, stride_dvn, seqlen_q, seqlen_k, headdim, ATOMIC_ADD=False, BIAS_TYPE=BIAS_TYPE, IS_CAUSAL=IS_CAUSAL, BLOCK_HEADDIM=BLOCK_HEADDIM, EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N)
325
  else:
326
  start_n = tl.program_id(0)
327
- _bwd_kernel_one_col_block(start_n, Q, K, V, Bias, DO, DQ, DK, DV, LSE, D, softmax_scale, stride_qm, stride_kn, stride_vn, stride_bm, stride_dom, stride_dqm, stride_dkn, stride_dvn, seqlen_q, seqlen_k, headdim, ATOMIC_ADD=True, BIAS_TYPE=BIAS_TYPE, IS_CAUSAL=IS_CAUSAL, BLOCK_HEADDIM=BLOCK_HEADDIM, EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N)
328
 
329
  def _flash_attn_forward(q, k, v, bias=None, causal=False, softmax_scale=None):
330
  (batch, seqlen_q, nheads, d) = q.shape
331
  (_, seqlen_k, _, _) = k.shape
332
  assert k.shape == (batch, seqlen_k, nheads, d)
333
  assert v.shape == (batch, seqlen_k, nheads, d)
334
- assert d <= 128, 'FlashAttention only support head dimensions up to 128'
335
- assert q.dtype == k.dtype == v.dtype, 'All tensors must have the same type'
336
- assert q.dtype in [torch.float16, torch.bfloat16], 'Only support fp16 and bf16'
337
  assert q.is_cuda and k.is_cuda and v.is_cuda
338
  softmax_scale = softmax_scale or 1.0 / math.sqrt(d)
339
  has_bias = bias is not None
340
- bias_type = 'none'
341
  if has_bias:
342
  assert bias.dtype in [q.dtype, torch.float]
343
  assert bias.is_cuda
@@ -345,25 +762,72 @@ def _flash_attn_forward(q, k, v, bias=None, causal=False, softmax_scale=None):
345
  if bias.stride(-1) != 1:
346
  bias = bias.contiguous()
347
  if bias.shape[2:] == (1, seqlen_k):
348
- bias_type = 'vector'
349
  elif bias.shape[2:] == (seqlen_q, seqlen_k):
350
- bias_type = 'matrix'
351
  else:
352
- raise RuntimeError('Last 2 dimensions of bias must be (1, seqlen_k) or (seqlen_q, seqlen_k)')
353
  bias = bias.expand(batch, nheads, seqlen_q, seqlen_k)
354
- bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0)
355
  seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128
356
- lse = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)
357
- tmp = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)
358
  o = torch.empty_like(q)
359
  BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)
360
  BLOCK = 128
361
  num_warps = 4 if d <= 64 else 8
362
- grid = lambda META: (triton.cdiv(seqlen_q, META['BLOCK_M']), batch * nheads)
363
- _fwd_kernel[grid](q, k, v, bias, o, lse, tmp, softmax_scale, q.stride(0), q.stride(2), q.stride(1), k.stride(0), k.stride(2), k.stride(1), v.stride(0), v.stride(2), v.stride(1), *bias_strides, o.stride(0), o.stride(2), o.stride(1), nheads, seqlen_q, seqlen_k, seqlen_q_rounded, d, seqlen_q // 32, seqlen_k // 32, bias_type, causal, BLOCK_HEADDIM, BLOCK_M=BLOCK, BLOCK_N=BLOCK, num_warps=num_warps, num_stages=1)
364
  return (o, lse, softmax_scale)
365
 
366
- def _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv, bias=None, causal=False, softmax_scale=None):
367
  if do.stride(-1) != 1:
368
  do = do.contiguous()
369
  (batch, seqlen_q, nheads, d) = q.shape
@@ -377,40 +841,115 @@ def _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv, bias=None, causal=Fals
377
  dq_accum = torch.empty_like(q, dtype=torch.float32)
378
  delta = torch.empty_like(lse)
379
  BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)
380
- grid = lambda META: (triton.cdiv(seqlen_q, META['BLOCK_M']), batch * nheads)
381
- _bwd_preprocess_do_o_dot[grid](o, do, delta, o.stride(0), o.stride(2), o.stride(1), do.stride(0), do.stride(2), do.stride(1), nheads, seqlen_q, seqlen_q_rounded, d, BLOCK_M=128, BLOCK_HEADDIM=BLOCK_HEADDIM)
382
  has_bias = bias is not None
383
- bias_type = 'none'
384
  if has_bias:
385
  assert bias.dtype in [q.dtype, torch.float]
386
  assert bias.is_cuda
387
  assert bias.dim() == 4
388
  assert bias.stride(-1) == 1
389
  if bias.shape[2:] == (1, seqlen_k):
390
- bias_type = 'vector'
391
  elif bias.shape[2:] == (seqlen_q, seqlen_k):
392
- bias_type = 'matrix'
393
  else:
394
- raise RuntimeError('Last 2 dimensions of bias must be (1, seqlen_k) or (seqlen_q, seqlen_k)')
395
  bias = bias.expand(batch, nheads, seqlen_q, seqlen_k)
396
- bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0)
397
- grid = lambda META: (triton.cdiv(seqlen_k, META['BLOCK_N']) if META['SEQUENCE_PARALLEL'] else 1, batch * nheads)
398
- _bwd_kernel[grid](q, k, v, bias, do, dq_accum, dk, dv, lse, delta, softmax_scale, q.stride(0), q.stride(2), q.stride(1), k.stride(0), k.stride(2), k.stride(1), v.stride(0), v.stride(2), v.stride(1), *bias_strides, do.stride(0), do.stride(2), do.stride(1), dq_accum.stride(0), dq_accum.stride(2), dq_accum.stride(1), dk.stride(0), dk.stride(2), dk.stride(1), dv.stride(0), dv.stride(2), dv.stride(1), nheads, seqlen_q, seqlen_k, seqlen_q_rounded, d, seqlen_q // 32, seqlen_k // 32, bias_type, causal, BLOCK_HEADDIM)
399
  dq.copy_(dq_accum)
400
 
 
401
  class FlashAttnQKVPackedFunc(torch.autograd.Function):
402
 
403
  @staticmethod
404
  def forward(ctx, qkv, bias=None, causal=False, softmax_scale=None):
405
  """
406
- qkv: (batch, seqlen, 3, nheads, headdim)
407
- bias: optional, shape broadcastible to (batch, nheads, seqlen, seqlen).
408
- For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen).
409
- ALiBi mask for non-causal would have shape (1, nheads, seqlen, seqlen)
410
  """
411
  if qkv.stride(-1) != 1:
412
  qkv = qkv.contiguous()
413
- (o, lse, ctx.softmax_scale) = _flash_attn_forward(qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2], bias=bias, causal=causal, softmax_scale=softmax_scale)
414
  ctx.save_for_backward(qkv, o, lse, bias)
415
  ctx.causal = causal
416
  return o
@@ -418,26 +957,51 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
418
  @staticmethod
419
  def backward(ctx, do):
420
  (qkv, o, lse, bias) = ctx.saved_tensors
421
- assert not ctx.needs_input_grad[1], 'FlashAttention does not support bias gradient yet'
422
  with torch.inference_mode():
423
  dqkv = torch.empty_like(qkv)
424
- _flash_attn_backward(do, qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2], o, lse, dqkv[:, :, 0], dqkv[:, :, 1], dqkv[:, :, 2], bias=bias, causal=ctx.causal, softmax_scale=ctx.softmax_scale)
425
  return (dqkv, None, None, None)
426
  flash_attn_qkvpacked_func = FlashAttnQKVPackedFunc.apply
427
 
 
428
  class FlashAttnKVPackedFunc(torch.autograd.Function):
429
 
430
  @staticmethod
431
  def forward(ctx, q, kv, bias=None, causal=False, softmax_scale=None):
432
  """
433
- q: (batch, seqlen_q, nheads, headdim)
434
- kv: (batch, seqlen_k, 2, nheads, headdim)
435
- bias: optional, shape broadcastible to (batch, nheads, seqlen_q, seqlen_k).
436
- For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen_k).
437
- ALiBi mask for non-causal would have shape (1, nheads, seqlen_q, seqlen_k)
438
  """
439
  (q, kv) = [x if x.stride(-1) == 1 else x.contiguous() for x in [q, kv]]
440
- (o, lse, ctx.softmax_scale) = _flash_attn_forward(q, kv[:, :, 0], kv[:, :, 1], bias=bias, causal=causal, softmax_scale=softmax_scale)
441
  ctx.save_for_backward(q, kv, o, lse, bias)
442
  ctx.causal = causal
443
  return o
@@ -446,27 +1010,47 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
446
  def backward(ctx, do):
447
  (q, kv, o, lse, bias) = ctx.saved_tensors
448
  if len(ctx.needs_input_grad) >= 3:
449
- assert not ctx.needs_input_grad[2], 'FlashAttention does not support bias gradient yet'
450
  with torch.inference_mode():
451
  dq = torch.empty_like(q)
452
  dkv = torch.empty_like(kv)
453
- _flash_attn_backward(do, q, kv[:, :, 0], kv[:, :, 1], o, lse, dq, dkv[:, :, 0], dkv[:, :, 1], bias=bias, causal=ctx.causal, softmax_scale=ctx.softmax_scale)
454
  return (dq, dkv, None, None, None)
455
  flash_attn_kvpacked_func = FlashAttnKVPackedFunc.apply
456
 
 
457
  class FlashAttnFunc(torch.autograd.Function):
458
 
459
  @staticmethod
460
  def forward(ctx, q, k, v, bias=None, causal=False, softmax_scale=None):
461
  """
462
- q: (batch_size, seqlen_q, nheads, headdim)
463
- k, v: (batch_size, seqlen_k, nheads, headdim)
464
- bias: optional, shape broadcastible to (batch, nheads, seqlen_q, seqlen_k).
465
- For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen_k).
466
- ALiBi mask for non-causal would have shape (1, nheads, seqlen_q, seqlen_k)
467
  """
468
  (q, k, v) = [x if x.stride(-1) == 1 else x.contiguous() for x in [q, k, v]]
469
- (o, lse, ctx.softmax_scale) = _flash_attn_forward(q, k, v, bias=bias, causal=causal, softmax_scale=softmax_scale)
470
  ctx.save_for_backward(q, k, v, o, lse, bias)
471
  ctx.causal = causal
472
  return o
@@ -474,11 +1058,28 @@ class FlashAttnFunc(torch.autograd.Function):
474
  @staticmethod
475
  def backward(ctx, do):
476
  (q, k, v, o, lse, bias) = ctx.saved_tensors
477
- assert not ctx.needs_input_grad[3], 'FlashAttention does not support bias gradient yet'
478
  with torch.inference_mode():
479
  dq = torch.empty_like(q)
480
  dk = torch.empty_like(k)
481
  dv = torch.empty_like(v)
482
- _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv, bias=bias, causal=ctx.causal, softmax_scale=ctx.softmax_scale)
483
  return (dq, dk, dv, None, None, None)
484
- flash_attn_func = FlashAttnFunc.apply
1
  """
2
  Copied from https://github.com/HazyResearch/flash-attention/blob/eff9fe6b8076df59d64d7a3f464696738a3c7c24/flash_attn/flash_attn_triton.py
3
  update imports to use 'triton_pre_mlir'
 
4
  *Experimental* implementation of FlashAttention in Triton.
5
  Tested with triton==2.0.0.dev20221202.
6
  Triton 2.0 has a new backend (MLIR) but seems like it doesn't yet work for head dimensions
7
  other than 64:
8
  https://github.com/openai/triton/blob/d376020f90002757eea3ea9475d4f7cfc2ec5ead/python/triton/ops/flash_attention.py#L207
9
  We'll update this implementation with the new Triton backend once this is fixed.
 
10
  We use the FlashAttention implementation from Phil Tillet a starting point.
11
  https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py
 
12
  Changes:
13
  - Implement both causal and non-causal attention.
14
  - Implement both self-attention and cross-attention.
 
19
  - Make the backward for d=128 much faster by reducing register spilling.
20
  - Optionally parallelize the backward pass across seqlen_k, to deal with the case of
21
  small batch size * nheads.
 
22
  Caution:
23
  - This is an *experimental* implementation. The forward pass should be quite robust but
24
  I'm not 100% sure that the backward pass doesn't have race conditions (due to the Triton compiler).
 
28
  "test_flash_attn_triton_race_condition". I've tested and fixed many race conditions
29
  for different head dimensions (40, 48, 64, 128, 80, 88, 96), but I'm still not 100% confident
30
  that there are none left for other head dimensions.
 
31
  Differences between this Triton version and the CUDA version:
32
  - Triton version doesn't support dropout.
33
  - Triton forward is generally faster than CUDA forward, while Triton backward is
 
36
  - Triton version doesn't support different sequence lengths in a batch (i.e., RaggedTensor/NestedTensor).
37
  - Triton version supports attention bias, while CUDA version doesn't.
38
  """
39
+
40
  import math
41
  import torch
42
  import triton_pre_mlir as triton
43
  import triton_pre_mlir.language as tl
44
 
45
+
46
+ @triton.heuristics(
47
+ {
48
+ "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
49
+ "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
50
+ "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
51
+ }
52
+ )
53
  @triton.jit
54
+ def _fwd_kernel(
55
+ Q,
56
+ K,
57
+ V,
58
+ Bias,
59
+ Out,
60
+ Lse,
61
+ TMP,
62
+ softmax_scale,
63
+ stride_qb,
64
+ stride_qh,
65
+ stride_qm,
66
+ stride_kb,
67
+ stride_kh,
68
+ stride_kn,
69
+ stride_vb,
70
+ stride_vh,
71
+ stride_vn,
72
+ stride_bb,
73
+ stride_bh,
74
+ stride_bm,
75
+ stride_ob,
76
+ stride_oh,
77
+ stride_om,
78
+ nheads,
79
+ seqlen_q,
80
+ seqlen_k,
81
+ seqlen_q_rounded,
82
+ headdim,
83
+ CACHE_KEY_SEQLEN_Q,
84
+ CACHE_KEY_SEQLEN_K,
85
+ BIAS_TYPE: tl.constexpr,
86
+ IS_CAUSAL: tl.constexpr,
87
+ BLOCK_HEADDIM: tl.constexpr,
88
+ EVEN_M: tl.constexpr,
89
+ EVEN_N: tl.constexpr,
90
+ EVEN_HEADDIM: tl.constexpr,
91
+ BLOCK_M: tl.constexpr,
92
+ BLOCK_N: tl.constexpr,
93
+ ):
94
  start_m = tl.program_id(0)
95
  off_hb = tl.program_id(1)
96
  off_b = off_hb // nheads
 
98
  offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
99
  offs_n = tl.arange(0, BLOCK_N)
100
  offs_d = tl.arange(0, BLOCK_HEADDIM)
101
+ q_ptrs = (
102
+ Q
103
+ + off_b * stride_qb
104
+ + off_h * stride_qh
105
+ + (offs_m[:, None] * stride_qm + offs_d[None, :])
106
+ )
107
+ k_ptrs = (
108
+ K
109
+ + off_b * stride_kb
110
+ + off_h * stride_kh
111
+ + (offs_n[:, None] * stride_kn + offs_d[None, :])
112
+ )
113
+ v_ptrs = (
114
+ V
115
+ + off_b * stride_vb
116
+ + off_h * stride_vh
117
+ + (offs_n[:, None] * stride_vn + offs_d[None, :])
118
+ )
119
+ if BIAS_TYPE == "vector":
120
  b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + offs_n
121
+ elif BIAS_TYPE == "matrix":
122
+ b_ptrs = (
123
+ Bias
124
+ + off_b * stride_bb
125
+ + off_h * stride_bh
126
+ + (offs_m[:, None] * stride_bm + offs_n[None, :])
127
+ )
128
  t_ptrs = TMP + off_hb * seqlen_q_rounded + offs_m
129
+ lse_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
130
+ m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
131
  acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)
132
  if EVEN_M & EVEN_N:
133
  if EVEN_HEADDIM:
 
137
  elif EVEN_HEADDIM:
138
  q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0)
139
  else:
140
+ q = tl.load(
141
+ q_ptrs,
142
+ mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
143
+ other=0.0,
144
+ )
145
  end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k)
146
  for start_n in range(0, end_n, BLOCK_N):
147
  start_n = tl.multiple_of(start_n, BLOCK_N)
 
149
  if EVEN_HEADDIM:
150
  k = tl.load(k_ptrs + start_n * stride_kn)
151
  else:
152
+ k = tl.load(
153
+ k_ptrs + start_n * stride_kn,
154
+ mask=offs_d[None, :] < headdim,
155
+ other=0.0,
156
+ )
157
  elif EVEN_HEADDIM:
158
+ k = tl.load(
159
+ k_ptrs + start_n * stride_kn,
160
+ mask=(start_n + offs_n)[:, None] < seqlen_k,
161
+ other=0.0,
162
+ )
163
  else:
164
+ k = tl.load(
165
+ k_ptrs + start_n * stride_kn,
166
+ mask=((start_n + offs_n)[:, None] < seqlen_k)
167
+ & (offs_d[None, :] < headdim),
168
+ other=0.0,
169
+ )
170
  qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
171
  qk += tl.dot(q, k, trans_b=True)
172
  if not EVEN_N:
173
+ qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float("-inf"))
174
  if IS_CAUSAL:
175
+ qk += tl.where(
176
+ offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf")
177
+ )
178
+ if BIAS_TYPE != "none":
179
+ if BIAS_TYPE == "vector":
180
  if EVEN_N:
181
  bias = tl.load(b_ptrs + start_n).to(tl.float32)
182
  else:
183
+ bias = tl.load(
184
+ b_ptrs + start_n, mask=start_n + offs_n < seqlen_k, other=0.0
185
+ ).to(tl.float32)
186
  bias = bias[None, :]
187
+ elif BIAS_TYPE == "matrix":
188
  if EVEN_M & EVEN_N:
189
  bias = tl.load(b_ptrs + start_n).to(tl.float32)
190
  else:
191
+ bias = tl.load(
192
+ b_ptrs + start_n,
193
+ mask=(offs_m[:, None] < seqlen_q)
194
+ & ((start_n + offs_n)[None, :] < seqlen_k),
195
+ other=0.0,
196
+ ).to(tl.float32)
197
  qk = qk * softmax_scale + bias
198
  m_ij = tl.maximum(tl.max(qk, 1), lse_i)
199
  p = tl.exp(qk - m_ij[:, None])
 
209
  if EVEN_HEADDIM:
210
  v = tl.load(v_ptrs + start_n * stride_vn)
211
  else:
212
+ v = tl.load(
213
+ v_ptrs + start_n * stride_vn,
214
+ mask=offs_d[None, :] < headdim,
215
+ other=0.0,
216
+ )
217
  elif EVEN_HEADDIM:
218
+ v = tl.load(
219
+ v_ptrs + start_n * stride_vn,
220
+ mask=(start_n + offs_n)[:, None] < seqlen_k,
221
+ other=0.0,
222
+ )
223
  else:
224
+ v = tl.load(
225
+ v_ptrs + start_n * stride_vn,
226
+ mask=((start_n + offs_n)[:, None] < seqlen_k)
227
+ & (offs_d[None, :] < headdim),
228
+ other=0.0,
229
+ )
230
  p = p.to(v.dtype)
231
  acc_o += tl.dot(p, v)
232
  m_i = m_ij
 
241
  lse_ptrs = Lse + off_hb * seqlen_q_rounded + offs_m
242
  tl.store(lse_ptrs, lse_i)
243
  offs_d = tl.arange(0, BLOCK_HEADDIM)
244
+ out_ptrs = (
245
+ Out
246
+ + off_b * stride_ob
247
+ + off_h * stride_oh
248
+ + (offs_m[:, None] * stride_om + offs_d[None, :])
249
+ )
250
  if EVEN_M:
251
  if EVEN_HEADDIM:
252
  tl.store(out_ptrs, acc_o)
 
255
  elif EVEN_HEADDIM:
256
  tl.store(out_ptrs, acc_o, mask=offs_m[:, None] < seqlen_q)
257
  else:
258
+ tl.store(
259
+ out_ptrs,
260
+ acc_o,
261
+ mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
262
+ )
263
+
264
 
265
  @triton.jit
266
+ def _bwd_preprocess_do_o_dot(
267
+ Out,
268
+ DO,
269
+ Delta,
270
+ stride_ob,
271
+ stride_oh,
272
+ stride_om,
273
+ stride_dob,
274
+ stride_doh,
275
+ stride_dom,
276
+ nheads,
277
+ seqlen_q,
278
+ seqlen_q_rounded,
279
+ headdim,
280
+ BLOCK_M: tl.constexpr,
281
+ BLOCK_HEADDIM: tl.constexpr,
282
+ ):
283
  start_m = tl.program_id(0)
284
  off_hb = tl.program_id(1)
285
  off_b = off_hb // nheads
286
  off_h = off_hb % nheads
287
  offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
288
  offs_d = tl.arange(0, BLOCK_HEADDIM)
289
+ o = tl.load(
290
+ Out
291
+ + off_b * stride_ob
292
+ + off_h * stride_oh
293
+ + offs_m[:, None] * stride_om
294
+ + offs_d[None, :],
295
+ mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
296
+ other=0.0,
297
+ ).to(tl.float32)
298
+ do = tl.load(
299
+ DO
300
+ + off_b * stride_dob
301
+ + off_h * stride_doh
302
+ + offs_m[:, None] * stride_dom
303
+ + offs_d[None, :],
304
+ mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
305
+ other=0.0,
306
+ ).to(tl.float32)
307
  delta = tl.sum(o * do, axis=1)
308
  tl.store(Delta + off_hb * seqlen_q_rounded + offs_m, delta)
309
 
310
+
311
  @triton.jit
312
+ def _bwd_store_dk_dv(
313
+ dk_ptrs,
314
+ dv_ptrs,
315
+ dk,
316
+ dv,
317
+ offs_n,
318
+ offs_d,
319
+ seqlen_k,
320
+ headdim,
321
+ EVEN_M: tl.constexpr,
322
+ EVEN_N: tl.constexpr,
323
+ EVEN_HEADDIM: tl.constexpr,
324
+ ):
325
  if EVEN_N & EVEN_M:
326
  if EVEN_HEADDIM:
327
  tl.store(dv_ptrs, dv)
 
333
  tl.store(dv_ptrs, dv, mask=offs_n[:, None] < seqlen_k)
334
  tl.store(dk_ptrs, dk, mask=offs_n[:, None] < seqlen_k)
335
  else:
336
+ tl.store(
337
+ dv_ptrs, dv, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim)
338
+ )
339
+ tl.store(
340
+ dk_ptrs, dk, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim)
341
+ )
342
+
343
 
344
  @triton.jit
345
+ def _bwd_kernel_one_col_block(
346
+ start_n,
347
+ Q,
348
+ K,
349
+ V,
350
+ Bias,
351
+ DO,
352
+ DQ,
353
+ DK,
354
+ DV,
355
+ LSE,
356
+ D,
357
+ softmax_scale,
358
+ stride_qm,
359
+ stride_kn,
360
+ stride_vn,
361
+ stride_bm,
362
+ stride_dom,
363
+ stride_dqm,
364
+ stride_dkn,
365
+ stride_dvn,
366
+ seqlen_q,
367
+ seqlen_k,
368
+ headdim,
369
+ ATOMIC_ADD: tl.constexpr,
370
+ BIAS_TYPE: tl.constexpr,
371
+ IS_CAUSAL: tl.constexpr,
372
+ BLOCK_HEADDIM: tl.constexpr,
373
+ EVEN_M: tl.constexpr,
374
+ EVEN_N: tl.constexpr,
375
+ EVEN_HEADDIM: tl.constexpr,
376
+ BLOCK_M: tl.constexpr,
377
+ BLOCK_N: tl.constexpr,
378
+ ):
379
  begin_m = 0 if not IS_CAUSAL else start_n * BLOCK_N // BLOCK_M * BLOCK_M
380
  offs_qm = begin_m + tl.arange(0, BLOCK_M)
381
  offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
 
386
  v_ptrs = V + (offs_n[:, None] * stride_vn + offs_d[None, :])
387
  do_ptrs = DO + (offs_qm[:, None] * stride_dom + offs_d[None, :])
388
  dq_ptrs = DQ + (offs_qm[:, None] * stride_dqm + offs_d[None, :])
389
+ if BIAS_TYPE == "vector":
390
  b_ptrs = Bias + offs_n
391
+ elif BIAS_TYPE == "matrix":
392
  b_ptrs = Bias + (offs_qm[:, None] * stride_bm + offs_n[None, :])
393
  dv = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
394
  dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
395
  if begin_m >= seqlen_q:
396
  dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :])
397
  dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :])
398
+ _bwd_store_dk_dv(
399
+ dk_ptrs,
400
+ dv_ptrs,
401
+ dk,
402
+ dv,
403
+ offs_n,
404
+ offs_d,
405
+ seqlen_k,
406
+ headdim,
407
+ EVEN_M=EVEN_M,
408
+ EVEN_N=EVEN_N,
409
+ EVEN_HEADDIM=EVEN_HEADDIM,
410
+ )
411
  return
412
  if EVEN_N & EVEN_M:
413
  if EVEN_HEADDIM:
 
420
  k = tl.load(k_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0)
421
  v = tl.load(v_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0)
422
  else:
423
+ k = tl.load(
424
+ k_ptrs,
425
+ mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
426
+ other=0.0,
427
+ )
428
+ v = tl.load(
429
+ v_ptrs,
430
+ mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
431
+ other=0.0,
432
+ )
433
  num_block_m = tl.cdiv(seqlen_q, BLOCK_M)
434
  for start_m in range(begin_m, num_block_m * BLOCK_M, BLOCK_M):
435
  start_m = tl.multiple_of(start_m, BLOCK_M)
 
439
  elif EVEN_HEADDIM:
440
  q = tl.load(q_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0)
441
  else:
442
+ q = tl.load(
443
+ q_ptrs,
444
+ mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
445
+ other=0.0,
446
+ )
447
  qk = tl.dot(q, k, trans_b=True)
448
  if not EVEN_N:
449
+ qk = tl.where(offs_n[None, :] < seqlen_k, qk, float("-inf"))
450
  if IS_CAUSAL:
451
+ qk = tl.where(offs_m_curr[:, None] >= offs_n[None, :], qk, float("-inf"))
452
+ if BIAS_TYPE != "none":
453
  tl.debug_barrier()
454
+ if BIAS_TYPE == "vector":
455
  if EVEN_N:
456
  bias = tl.load(b_ptrs).to(tl.float32)
457
  else:
458
+ bias = tl.load(b_ptrs, mask=offs_n < seqlen_k, other=0.0).to(
459
+ tl.float32
460
+ )
461
  bias = bias[None, :]
462
+ elif BIAS_TYPE == "matrix":
463
  if EVEN_M & EVEN_N:
464
  bias = tl.load(b_ptrs).to(tl.float32)
465
  else:
466
+ bias = tl.load(
467
+ b_ptrs,
468
+ mask=(offs_m_curr[:, None] < seqlen_q)
469
+ & (offs_n[None, :] < seqlen_k),
470
+ other=0.0,
471
+ ).to(tl.float32)
472
  qk = qk * softmax_scale + bias
473
  if not EVEN_M & EVEN_HEADDIM:
474
  tl.debug_barrier()
475
  lse_i = tl.load(LSE + offs_m_curr)
476
+ if BIAS_TYPE == "none":
477
  p = tl.exp(qk * softmax_scale - lse_i[:, None])
478
  else:
479
  p = tl.exp(qk - lse_i[:, None])
480
  if EVEN_M & EVEN_HEADDIM:
481
  do = tl.load(do_ptrs)
482
  else:
483
+ do = tl.load(
484
+ do_ptrs,
485
+ mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
486
+ other=0.0,
487
+ )
488
  dv += tl.dot(p.to(do.dtype), do, trans_a=True)
489
  if not EVEN_M & EVEN_HEADDIM:
490
  tl.debug_barrier()
 
498
  tl.debug_barrier()
499
  if not ATOMIC_ADD:
500
  if EVEN_M & EVEN_HEADDIM:
501
+ dq = tl.load(dq_ptrs, eviction_policy="evict_last")
502
  dq += tl.dot(ds, k)
503
+ tl.store(dq_ptrs, dq, eviction_policy="evict_last")
504
  elif EVEN_HEADDIM:
505
+ dq = tl.load(
506
+ dq_ptrs,
507
+ mask=offs_m_curr[:, None] < seqlen_q,
508
+ other=0.0,
509
+ eviction_policy="evict_last",
510
+ )
511
  dq += tl.dot(ds, k)
512
+ tl.store(
513
+ dq_ptrs,
514
+ dq,
515
+ mask=offs_m_curr[:, None] < seqlen_q,
516
+ eviction_policy="evict_last",
517
+ )
518
  else:
519
+ dq = tl.load(
520
+ dq_ptrs,
521
+ mask=(offs_m_curr[:, None] < seqlen_q)
522
+ & (offs_d[None, :] < headdim),
523
+ other=0.0,
524
+ eviction_policy="evict_last",
525
+ )
526
  dq += tl.dot(ds, k)
527
+ tl.store(
528
+ dq_ptrs,
529
+ dq,
530
+ mask=(offs_m_curr[:, None] < seqlen_q)
531
+ & (offs_d[None, :] < headdim),
532
+ eviction_policy="evict_last",
533
+ )
534
  else:
535
  dq = tl.dot(ds, k)
536
  if EVEN_M & EVEN_HEADDIM:
 
538
  elif EVEN_HEADDIM:
539
  tl.atomic_add(dq_ptrs, dq, mask=offs_m_curr[:, None] < seqlen_q)
540
  else:
541
+ tl.atomic_add(
542
+ dq_ptrs,
543
+ dq,
544
+ mask=(offs_m_curr[:, None] < seqlen_q)
545
+ & (offs_d[None, :] < headdim),
546
+ )
547
  dq_ptrs += BLOCK_M * stride_dqm
548
  q_ptrs += BLOCK_M * stride_qm
549
  do_ptrs += BLOCK_M * stride_dom
550
+ if BIAS_TYPE == "matrix":
551
  b_ptrs += BLOCK_M * stride_bm
552
  dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :])
553
  dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :])
554
+ _bwd_store_dk_dv(
555
+ dk_ptrs,
556
+ dv_ptrs,
557
+ dk,
558
+ dv,
559
+ offs_n,
560
+ offs_d,
561
+ seqlen_k,
562
+ headdim,
563
+ EVEN_M=EVEN_M,
564
+ EVEN_N=EVEN_N,
565
+ EVEN_HEADDIM=EVEN_HEADDIM,
566
+ )
567
+
568
 
569
  def init_to_zero(name):
570
  return lambda nargs: nargs[name].zero_()
571
 
572
+
573
+ @triton.autotune(
574
+ configs=[
575
+ triton.Config(
576
+ {"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": False},
577
+ num_warps=8,
578
+ num_stages=1,
579
+ pre_hook=init_to_zero("DQ"),
580
+ ),
581
+ triton.Config(
582
+ {"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": True},
583
+ num_warps=8,
584
+ num_stages=1,
585
+ pre_hook=init_to_zero("DQ"),
586
+ ),
587
+ ],
588
+ key=[
589
+ "CACHE_KEY_SEQLEN_Q",
590
+ "CACHE_KEY_SEQLEN_K",
591
+ "BIAS_TYPE",
592
+ "IS_CAUSAL",
593
+ "BLOCK_HEADDIM",
594
+ ],
595
+ )
596
+ @triton.heuristics(
597
+ {
598
+ "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
599
+ "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
600
+ "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
601
+ }
602
+ )
603
  @triton.jit
604
+ def _bwd_kernel(
605
+ Q,
606
+ K,
607
+ V,
608
+ Bias,
609
+ DO,
610
+ DQ,
611
+ DK,
612
+ DV,
613
+ LSE,
614
+ D,
615
+ softmax_scale,
616
+ stride_qb,
617
+ stride_qh,
618
+ stride_qm,
619
+ stride_kb,
620
+ stride_kh,
621
+ stride_kn,
622
+ stride_vb,
623
+ stride_vh,
624
+ stride_vn,
625
+ stride_bb,
626
+ stride_bh,
627
+ stride_bm,
628
+ stride_dob,
629
+ stride_doh,
630
+ stride_dom,
631
+ stride_dqb,
632
+ stride_dqh,
633
+ stride_dqm,
634
+ stride_dkb,
635
+ stride_dkh,
636
+ stride_dkn,
637
+ stride_dvb,
638
+ stride_dvh,
639
+ stride_dvn,
640
+ nheads,
641
+ seqlen_q,
642
+ seqlen_k,
643
+ seqlen_q_rounded,
644
+ headdim,
645
+ CACHE_KEY_SEQLEN_Q,
646
+ CACHE_KEY_SEQLEN_K,
647
+ BIAS_TYPE: tl.constexpr,
648
+ IS_CAUSAL: tl.constexpr,
649
+ BLOCK_HEADDIM: tl.constexpr,
650
+ SEQUENCE_PARALLEL: tl.constexpr,
651
+ EVEN_M: tl.constexpr,
652
+ EVEN_N: tl.constexpr,
653
+ EVEN_HEADDIM: tl.constexpr,
654
+ BLOCK_M: tl.constexpr,
655
+ BLOCK_N: tl.constexpr,
656
+ ):
657
  off_hb = tl.program_id(1)
658
  off_b = off_hb // nheads
659
  off_h = off_hb % nheads
 
664
  DQ += off_b * stride_dqb + off_h * stride_dqh
665
  DK += off_b * stride_dkb + off_h * stride_dkh
666
  DV += off_b * stride_dvb + off_h * stride_dvh
667
+ if BIAS_TYPE != "none":
668
  Bias += off_b * stride_bb + off_h * stride_bh
669
  D += off_hb * seqlen_q_rounded
670
  LSE += off_hb * seqlen_q_rounded
671
  if not SEQUENCE_PARALLEL:
672
  num_block_n = tl.cdiv(seqlen_k, BLOCK_N)
673
  for start_n in range(0, num_block_n):
674
+ _bwd_kernel_one_col_block(
675
+ start_n,
676
+ Q,
677
+ K,
678
+ V,
679
+ Bias,
680
+ DO,
681
+ DQ,
682
+ DK,
683
+ DV,
684
+ LSE,
685
+ D,
686
+ softmax_scale,
687
+ stride_qm,
688
+ stride_kn,
689
+ stride_vn,
690
+ stride_bm,
691
+ stride_dom,
692
+ stride_dqm,
693
+ stride_dkn,
694
+ stride_dvn,
695
+ seqlen_q,
696
+ seqlen_k,
697
+ headdim,
698
+ ATOMIC_ADD=False,
699
+ BIAS_TYPE=BIAS_TYPE,
700
+ IS_CAUSAL=IS_CAUSAL,
701
+ BLOCK_HEADDIM=BLOCK_HEADDIM,
702
+ EVEN_M=EVEN_M,
703
+ EVEN_N=EVEN_N,
704
+ EVEN_HEADDIM=EVEN_HEADDIM,
705
+ BLOCK_M=BLOCK_M,
706
+ BLOCK_N=BLOCK_N,
707
+ )
708
  else:
709
  start_n = tl.program_id(0)
710
+ _bwd_kernel_one_col_block(
711
+ start_n,
712
+ Q,
713
+ K,
714
+ V,
715
+ Bias,
716
+ DO,
717
+ DQ,
718
+ DK,
719
+ DV,
720
+ LSE,
721
+ D,
722
+ softmax_scale,
723
+ stride_qm,
724
+ stride_kn,
725
+ stride_vn,
726
+ stride_bm,
727
+ stride_dom,
728
+ stride_dqm,
729
+ stride_dkn,
730
+ stride_dvn,
731
+ seqlen_q,
732
+ seqlen_k,
733
+ headdim,
734
+ ATOMIC_ADD=True,
735
+ BIAS_TYPE=BIAS_TYPE,
736
+ IS_CAUSAL=IS_CAUSAL,
737
+ BLOCK_HEADDIM=BLOCK_HEADDIM,
738
+ EVEN_M=EVEN_M,
739
+ EVEN_N=EVEN_N,
740
+ EVEN_HEADDIM=EVEN_HEADDIM,
741
+ BLOCK_M=BLOCK_M,
742
+ BLOCK_N=BLOCK_N,
743
+ )
744
+
745
 
746
  def _flash_attn_forward(q, k, v, bias=None, causal=False, softmax_scale=None):
747
  (batch, seqlen_q, nheads, d) = q.shape
748
  (_, seqlen_k, _, _) = k.shape
749
  assert k.shape == (batch, seqlen_k, nheads, d)
750
  assert v.shape == (batch, seqlen_k, nheads, d)
751
+ assert d <= 128, "FlashAttention only support head dimensions up to 128"
752
+ assert q.dtype == k.dtype == v.dtype, "All tensors must have the same type"
753
+ assert q.dtype in [torch.float16, torch.bfloat16], "Only support fp16 and bf16"
754
  assert q.is_cuda and k.is_cuda and v.is_cuda
755
  softmax_scale = softmax_scale or 1.0 / math.sqrt(d)
756
  has_bias = bias is not None
757
+ bias_type = "none"
758
  if has_bias:
759
  assert bias.dtype in [q.dtype, torch.float]
760
  assert bias.is_cuda
 
762
  if bias.stride(-1) != 1:
763
  bias = bias.contiguous()
764
  if bias.shape[2:] == (1, seqlen_k):
765
+ bias_type = "vector"
766
  elif bias.shape[2:] == (seqlen_q, seqlen_k):
767
+ bias_type = "matrix"
768
  else:
769
+ raise RuntimeError(
770
+ "Last 2 dimensions of bias must be (1, seqlen_k) or (seqlen_q, seqlen_k)"
771
+ )
772
  bias = bias.expand(batch, nheads, seqlen_q, seqlen_k)
773
+ bias_strides = (
774
+ (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0)
775
+ )
776
  seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128
777
+ lse = torch.empty(
778
+ (batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32
779
+ )
780
+ tmp = torch.empty(
781
+ (batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32
782
+ )
783
  o = torch.empty_like(q)
784
  BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)
785
  BLOCK = 128
786
  num_warps = 4 if d <= 64 else 8
787
+ grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads)
788
+ _fwd_kernel[grid](
789
+ q,
790
+ k,
791
+ v,
792
+ bias,
793
+ o,
794
+ lse,
795
+ tmp,
796
+ softmax_scale,
797
+ q.stride(0),
798
+ q.stride(2),
799
+ q.stride(1),
800
+ k.stride(0),
801
+ k.stride(2),
802
+ k.stride(1),
803
+ v.stride(0),
804
+ v.stride(2),
805
+ v.stride(1),
806
+ *bias_strides,
807
+ o.stride(0),
808
+ o.stride(2),
809
+ o.stride(1),
810
+ nheads,
811
+ seqlen_q,
812
+ seqlen_k,
813
+ seqlen_q_rounded,
814
+ d,
815
+ seqlen_q // 32,
816
+ seqlen_k // 32,
817
+ bias_type,
818
+ causal,
819
+ BLOCK_HEADDIM,
820
+ BLOCK_M=BLOCK,
821
+ BLOCK_N=BLOCK,
822
+ num_warps=num_warps,
823
+ num_stages=1
824
+ )
825
  return (o, lse, softmax_scale)
826
 
827
+
828
+ def _flash_attn_backward(
829
+ do, q, k, v, o, lse, dq, dk, dv, bias=None, causal=False, softmax_scale=None
830
+ ):
831
  if do.stride(-1) != 1:
832
  do = do.contiguous()
833
  (batch, seqlen_q, nheads, d) = q.shape
 
841
  dq_accum = torch.empty_like(q, dtype=torch.float32)
842
  delta = torch.empty_like(lse)
843
  BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)
844
+ grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads)
845
+ _bwd_preprocess_do_o_dot[grid](
846
+ o,
847
+ do,
848
+ delta,
849
+ o.stride(0),
850
+ o.stride(2),
851
+ o.stride(1),
852
+ do.stride(0),
853
+ do.stride(2),
854
+ do.stride(1),
855
+ nheads,
856
+ seqlen_q,
857
+ seqlen_q_rounded,
858
+ d,
859
+ BLOCK_M=128,
860
+ BLOCK_HEADDIM=BLOCK_HEADDIM,
861
+ )
862
  has_bias = bias is not None
863
+ bias_type = "none"
864
  if has_bias:
865
  assert bias.dtype in [q.dtype, torch.float]
866
  assert bias.is_cuda
867
  assert bias.dim() == 4
868
  assert bias.stride(-1) == 1
869
  if bias.shape[2:] == (1, seqlen_k):
870
+ bias_type = "vector"
871
  elif bias.shape[2:] == (seqlen_q, seqlen_k):
872
+ bias_type = "matrix"
873
  else:
874
+ raise RuntimeError(
875
+ "Last 2 dimensions of bias must be (1, seqlen_k) or (seqlen_q, seqlen_k)"
876
+ )
877
  bias = bias.expand(batch, nheads, seqlen_q, seqlen_k)
878
+ bias_strides = (
879
+ (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0)
880
+ )
881
+ grid = lambda META: (
882
+ triton.cdiv(seqlen_k, META["BLOCK_N"]) if META["SEQUENCE_PARALLEL"] else 1,
883
+ batch * nheads,
884
+ )
885
+ _bwd_kernel[grid](
886
+ q,
887
+ k,
888
+ v,
889
+ bias,
890
+ do,
891
+ dq_accum,
892
+ dk,
893
+ dv,
894
+ lse,
895
+ delta,
896
+ softmax_scale,
897
+ q.stride(0),
898
+ q.stride(2),
899
+ q.stride(1),
900
+ k.stride(0),
901
+ k.stride(2),
902
+ k.stride(1),
903
+ v.stride(0),
904
+ v.stride(2),
905
+ v.stride(1),
906
+ *bias_strides,
907
+ do.stride(0),
908
+ do.stride(2),
909
+ do.stride(1),
910
+ dq_accum.stride(0),
911
+ dq_accum.stride(2),
912
+ dq_accum.stride(1),
913
+ dk.stride(0),
914
+ dk.stride(2),
915
+ dk.stride(1),
916
+ dv.stride(0),
917
+ dv.stride(2),
918
+ dv.stride(1),
919
+ nheads,
920
+ seqlen_q,
921
+ seqlen_k,
922
+ seqlen_q_rounded,
923
+ d,
924
+ seqlen_q // 32,
925
+ seqlen_k // 32,
926
+ bias_type,
927
+ causal,
928
+ BLOCK_HEADDIM
929
+ )
930
  dq.copy_(dq_accum)
931
 
932
+
933
  class FlashAttnQKVPackedFunc(torch.autograd.Function):
934
 
935
  @staticmethod
936
  def forward(ctx, qkv, bias=None, causal=False, softmax_scale=None):
937
  """
938
+ qkv: (batch, seqlen, 3, nheads, headdim)
939
+ bias: optional, shape broadcastible to (batch, nheads, seqlen, seqlen).
940
+ For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen).
941
+ ALiBi mask for non-causal would have shape (1, nheads, seqlen, seqlen)
942
  """
943
  if qkv.stride(-1) != 1:
944
  qkv = qkv.contiguous()
945
+ (o, lse, ctx.softmax_scale) = _flash_attn_forward(
946
+ qkv[:, :, 0],
947
+ qkv[:, :, 1],
948
+ qkv[:, :, 2],
949
+ bias=bias,
950
+ causal=causal,
951
+ softmax_scale=softmax_scale,
952
+ )
953
  ctx.save_for_backward(qkv, o, lse, bias)
954
  ctx.causal = causal
955
  return o
 
957
  @staticmethod
958
  def backward(ctx, do):
959
  (qkv, o, lse, bias) = ctx.saved_tensors
960
+ assert not ctx.needs_input_grad[
961
+ 1
962
+ ], "FlashAttention does not support bias gradient yet"
963
  with torch.inference_mode():
964
  dqkv = torch.empty_like(qkv)
965
+ _flash_attn_backward(
966
+ do,
967
+ qkv[:, :, 0],
968
+ qkv[:, :, 1],
969
+ qkv[:, :, 2],
970
+ o,
971
+ lse,
972
+ dqkv[:, :, 0],
973
+ dqkv[:, :, 1],
974
+ dqkv[:, :, 2],
975
+ bias=bias,
976
+ causal=ctx.causal,
977
+ softmax_scale=ctx.softmax_scale,
978
+ )
979
  return (dqkv, None, None, None)
980
+
981
+
982
  flash_attn_qkvpacked_func = FlashAttnQKVPackedFunc.apply
983
 
984
+
985
  class FlashAttnKVPackedFunc(torch.autograd.Function):
986
 
987
  @staticmethod
988
  def forward(ctx, q, kv, bias=None, causal=False, softmax_scale=None):
989
  """
990
+ q: (batch, seqlen_q, nheads, headdim)
991
+ kv: (batch, seqlen_k, 2, nheads, headdim)
992
+ bias: optional, shape broadcastible to (batch, nheads, seqlen_q, seqlen_k).
993
+ For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen_k).
994
+ ALiBi mask for non-causal would have shape (1, nheads, seqlen_q, seqlen_k)
995
  """
996
  (q, kv) = [x if x.stride(-1) == 1 else x.contiguous() for x in [q, kv]]
997
+ (o, lse, ctx.softmax_scale) = _flash_attn_forward(
998
+ q,
999
+ kv[:, :, 0],
1000
+ kv[:, :, 1],
1001
+ bias=bias,
1002
+ causal=causal,
1003
+ softmax_scale=softmax_scale,
1004
+ )
1005
  ctx.save_for_backward(q, kv, o, lse, bias)
1006
  ctx.causal = causal
1007
  return o
 
1010
  def backward(ctx, do):
1011
  (q, kv, o, lse, bias) = ctx.saved_tensors
1012
  if len(ctx.needs_input_grad) >= 3:
1013
+ assert not ctx.needs_input_grad[
1014
+ 2
1015
+ ], "FlashAttention does not support bias gradient yet"
1016
  with torch.inference_mode():
1017
  dq = torch.empty_like(q)
1018
  dkv = torch.empty_like(kv)
1019
+ _flash_attn_backward(
1020
+ do,
1021
+ q,
1022
+ kv[:, :, 0],
1023
+ kv[:, :, 1],
1024
+ o,
1025
+ lse,
1026
+ dq,
1027
+ dkv[:, :, 0],
1028
+ dkv[:, :, 1],
1029
+ bias=bias,
1030
+ causal=ctx.causal,
1031
+ softmax_scale=ctx.softmax_scale,
1032
+ )
1033
  return (dq, dkv, None, None, None)
1034
+
1035
+
1036
  flash_attn_kvpacked_func = FlashAttnKVPackedFunc.apply
1037
 
1038
+
1039
  class FlashAttnFunc(torch.autograd.Function):
1040
 
1041
  @staticmethod
1042
  def forward(ctx, q, k, v, bias=None, causal=False, softmax_scale=None):
1043
  """
1044
+ q: (batch_size, seqlen_q, nheads, headdim)
1045
+ k, v: (batch_size, seqlen_k, nheads, headdim)
1046
+ bias: optional, shape broadcastible to (batch, nheads, seqlen_q, seqlen_k).
1047
+ For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen_k).
1048
+ ALiBi mask for non-causal would have shape (1, nheads, seqlen_q, seqlen_k)
1049
  """
1050
  (q, k, v) = [x if x.stride(-1) == 1 else x.contiguous() for x in [q, k, v]]
1051
+ (o, lse, ctx.softmax_scale) = _flash_attn_forward(
1052
+ q, k, v, bias=bias, causal=causal, softmax_scale=softmax_scale
1053
+ )
1054
  ctx.save_for_backward(q, k, v, o, lse, bias)
1055
  ctx.causal = causal
1056
  return o
 
1058
  @staticmethod
1059
  def backward(ctx, do):
1060
  (q, k, v, o, lse, bias) = ctx.saved_tensors
1061
+ assert not ctx.needs_input_grad[
1062
+ 3
1063
+ ], "FlashAttention does not support bias gradient yet"
1064
  with torch.inference_mode():
1065
  dq = torch.empty_like(q)
1066
  dk = torch.empty_like(k)
1067
  dv = torch.empty_like(v)
1068
+ _flash_attn_backward(
1069
+ do,
1070
+ q,
1071
+ k,
1072
+ v,
1073
+ o,
1074
+ lse,
1075
+ dq,
1076
+ dk,
1077
+ dv,
1078
+ bias=bias,
1079
+ causal=ctx.causal,
1080
+ softmax_scale=ctx.softmax_scale,
1081
+ )
1082
  return (dq, dk, dv, None, None, None)
1083
+
1084
+
1085
+ flash_attn_func = FlashAttnFunc.apply
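For reviewers, a minimal usage sketch of the interface this file exposes (`flash_attn_func = FlashAttnFunc.apply`). Per the asserts in `_flash_attn_forward`, inputs must be CUDA tensors in fp16/bf16 with head dimension ≤ 128; the import path below is hypothetical and the snippet is an illustration, not part of the commit.

```python
import torch
# Hypothetical import path; adjust to wherever this module lives in the checkpoint repo.
from flash_attn_triton import flash_attn_func

batch, seqlen, nheads, headdim = 2, 128, 8, 64
q = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.float16, requires_grad=True)
k = torch.randn_like(q, requires_grad=True)
v = torch.randn_like(q, requires_grad=True)

# Positional args match FlashAttnFunc.forward(ctx, q, k, v, bias, causal, softmax_scale);
# bias=None selects BIAS_TYPE="none", causal=True applies the causal mask inside the kernel.
out = flash_attn_func(q, k, v, None, True)
out.sum().backward()  # exercises _flash_attn_backward / _bwd_kernel
```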
hf_prefixlm_converter.py CHANGED
@@ -6,6 +6,7 @@ Causal LM to convert it to a Prefix LM.
6
  Prefix LMs accepts a `bidirectional_mask` input in `forward`
7
  and treat the input prompt as the prefix in `generate`.
8
  """
 
9
  from types import MethodType
10
  from typing import Any, List, MutableMapping, Optional, Tuple, Union
11
  import torch
 
6
  Prefix LMs accepts a `bidirectional_mask` input in `forward`
7
  and treat the input prompt as the prefix in `generate`.
8
  """
9
+
10
  from types import MethodType
11
  from typing import Any, List, MutableMapping, Optional, Tuple, Union
12
  import torch
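The docstring above says converted Prefix LMs accept a `bidirectional_mask` in `forward` and treat the input prompt as the prefix in `generate`. As a hedged illustration of what such a mask could look like (assuming 1 marks prefix/prompt positions and 0 marks the causal continuation; the actual conversion helpers appear later in this file):

```python
import torch

input_ids = torch.tensor([[5, 23, 99, 7, 7, 7]])  # toy batch: 3 prompt tokens, 3 continuation tokens
prefix_len = 3
bidirectional_mask = torch.zeros_like(input_ids)
bidirectional_mask[:, :prefix_len] = 1  # bidirectional attention over the prompt, causal afterwards

# outputs = model(input_ids=input_ids, bidirectional_mask=bidirectional_mask)  # hypothetical call
```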
meta_init_context.py CHANGED
@@ -3,8 +3,9 @@ from typing import Any, Callable, Optional
3
  import torch
4
  import torch.nn as nn
5
 
 
6
  @contextmanager
7
- def init_empty_weights(include_buffers: bool=False):
8
  """Meta initialization context manager.
9
 
10
  A context manager under which models are initialized with all parameters
@@ -31,11 +32,12 @@ def init_empty_weights(include_buffers: bool=False):
31
 
32
  </Tip>
33
  """
34
- with init_on_device(torch.device('meta'), include_buffers=include_buffers) as f:
35
  yield f
36
 
 
37
  @contextmanager
38
- def init_on_device(device: torch.device, include_buffers: bool=False):
39
  """Device initialization context manager.
40
 
41
  A context manager under which models are initialized with all parameters
@@ -58,7 +60,9 @@ def init_on_device(device: torch.device, include_buffers: bool=False):
58
  if include_buffers:
59
  old_register_buffer = nn.Module.register_buffer
60
 
61
- def register_empty_parameter(self: torch.nn.Module, name: str, param: Optional[torch.nn.Parameter]):
62
  old_register_parameter(self, name, param)
63
  if param is not None:
64
  parameter = self._parameters[name]
@@ -67,33 +71,51 @@ def init_on_device(device: torch.device, include_buffers: bool=False):
67
  kwargs = parameter.__dict__
68
  self._parameters[name] = param_cls(parameter.to(device), **kwargs)
69
 
70
- def register_empty_buffer(self: torch.nn.Module, name: str, tensor: Optional[torch.Tensor], persistent: bool=True):
71
  old_register_buffer(self, name, tensor, persistent=persistent)
72
  if tensor is not None:
73
  named_buffer = self._buffers[name]
74
  assert named_buffer is not None
75
  self._buffers[name] = named_buffer.to(device)
 
76
  if include_buffers:
77
- tensor_constructors_to_patch = {torch_function_name: getattr(torch, torch_function_name) for torch_function_name in ['empty', 'zeros', 'ones', 'full']}
78
  else:
79
  tensor_constructors_to_patch = {}
80
 
81
  def patch_tensor_constructor(fn: Callable):
82
 
83
  def wrapper(*args: Any, **kwargs: Any):
84
- kwargs['device'] = device
85
  return fn(*args, **kwargs)
 
86
  return wrapper
 
87
  try:
88
  nn.Module.register_parameter = register_empty_parameter
89
  if include_buffers:
90
  nn.Module.register_buffer = register_empty_buffer
91
  for torch_function_name in tensor_constructors_to_patch.keys():
92
- setattr(torch, torch_function_name, patch_tensor_constructor(getattr(torch, torch_function_name)))
93
  yield
94
  finally:
95
  nn.Module.register_parameter = old_register_parameter
96
  if include_buffers:
97
  nn.Module.register_buffer = old_register_buffer
98
- for (torch_function_name, old_torch_function) in tensor_constructors_to_patch.items():
99
- setattr(torch, torch_function_name, old_torch_function)
3
  import torch
4
  import torch.nn as nn
5
 
6
+
7
  @contextmanager
8
+ def init_empty_weights(include_buffers: bool = False):
9
  """Meta initialization context manager.
10
 
11
  A context manager under which models are initialized with all parameters
 
32
 
33
  </Tip>
34
  """
35
+ with init_on_device(torch.device("meta"), include_buffers=include_buffers) as f:
36
  yield f
37
 
38
+
39
  @contextmanager
40
+ def init_on_device(device: torch.device, include_buffers: bool = False):
41
  """Device initialization context manager.
42
 
43
  A context manager under which models are initialized with all parameters
 
60
  if include_buffers:
61
  old_register_buffer = nn.Module.register_buffer
62
 
63
+ def register_empty_parameter(
64
+ self: torch.nn.Module, name: str, param: Optional[torch.nn.Parameter]
65
+ ):
66
  old_register_parameter(self, name, param)
67
  if param is not None:
68
  parameter = self._parameters[name]
 
71
  kwargs = parameter.__dict__
72
  self._parameters[name] = param_cls(parameter.to(device), **kwargs)
73
 
74
+ def register_empty_buffer(
75
+ self: torch.nn.Module,
76
+ name: str,
77
+ tensor: Optional[torch.Tensor],
78
+ persistent: bool = True,
79
+ ):
80
  old_register_buffer(self, name, tensor, persistent=persistent)
81
  if tensor is not None:
82
  named_buffer = self._buffers[name]
83
  assert named_buffer is not None
84
  self._buffers[name] = named_buffer.to(device)
85
+
86
  if include_buffers:
87
+ tensor_constructors_to_patch = {
88
+ torch_function_name: getattr(torch, torch_function_name)
89
+ for torch_function_name in ["empty", "zeros", "ones", "full"]
90
+ }
91
  else:
92
  tensor_constructors_to_patch = {}
93
 
94
  def patch_tensor_constructor(fn: Callable):
95
 
96
  def wrapper(*args: Any, **kwargs: Any):
97
+ kwargs["device"] = device
98
  return fn(*args, **kwargs)
99
+
100
  return wrapper
101
+
102
  try:
103
  nn.Module.register_parameter = register_empty_parameter
104
  if include_buffers:
105
  nn.Module.register_buffer = register_empty_buffer
106
  for torch_function_name in tensor_constructors_to_patch.keys():
107
+ setattr(
108
+ torch,
109
+ torch_function_name,
110
+ patch_tensor_constructor(getattr(torch, torch_function_name)),
111
+ )
112
  yield
113
  finally:
114
  nn.Module.register_parameter = old_register_parameter
115
  if include_buffers:
116
  nn.Module.register_buffer = old_register_buffer
117
+ for (
118
+ torch_function_name,
119
+ old_torch_function,
120
+ ) in tensor_constructors_to_patch.items():
121
+ setattr(torch, torch_function_name, old_torch_function)
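A short usage sketch of the two context managers in this file: parameters (and, optionally, buffers) registered while the block is active are moved to the target device, so `init_empty_weights` builds a model skeleton on the `meta` device without allocating real memory. The import path is assumed; illustration only.

```python
import torch
import torch.nn as nn
# Hypothetical import path for the context managers defined above.
from meta_init_context import init_empty_weights, init_on_device

with init_empty_weights():
    model = nn.Linear(4096, 4096)
print(model.weight.device)  # meta -> no memory allocated yet

with init_on_device(torch.device("cpu"), include_buffers=True):
    bn = nn.BatchNorm1d(16)  # running_mean / running_var buffers also land on the target device
```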
model-00001-of-00004.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:ec05e038ef093118222d561cf5ff110721cebb4033106cbdd9193dc601722cb4
3
  size 4933505648
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d9b11a607384278e4b241042d9daf5bf81228e712c14512e4b9fc8a456e3447b
3
  size 4933505648
model-00002-of-00004.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:a9c0acf0dc6025344ced5ab450d75e5bfcf62868e78a0c060c5c9fffdb286deb
3
  size 4967831752
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b1a3ca63ac33e1f432ef6f9fa53e04912eb8c8a2093991d07fd8bce6a34708bf
3
  size 4967831752
model-00003-of-00004.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:7a887263b29d7b702c46295b48d28c95086ba4b9f666900577b91836e722a420
3
  size 4967781776
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3f591ec5d6445a4193146bf8782173219a1d13bb1f004468116e6d4e8276efa4
3
  size 4967781776
model-00004-of-00004.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:089b687a1102645e904f39c98094ace78ec3ed638124ec36d2a8ac11d402a047
3
  size 134242752
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:56a7ffbb1197228d6e4cb970755f61ec74925294945602c76f5a8584e4da7952
3
  size 134242752
modeling_mpt.py CHANGED
@@ -2,24 +2,42 @@
2
 
3
  Inspired by https://github.com/karpathy/minGPT/blob/master/mingpt/model.py
4
  """
 
  import math
6
  import warnings
7
  from typing import Any, Dict, List, Mapping, MutableMapping, Optional, Tuple, Union
8
  import torch
9
  import torch.nn as nn
10
  import torch.nn.functional as F
11
  from transformers import PreTrainedModel, PreTrainedTokenizerBase
12
  from transformers.modeling_outputs import (
13
  BaseModelOutputWithPast,
14
  CausalLMOutputWithPast,
15
  )
16
-
17
- from .attention import (
18
- MultiheadAttention,
19
- MultiQueryAttention,
20
- attn_bias_shape,
21
- build_attn_bias,
22
  )
 
23
  from .blocks import MPTBlock
24
  from .custom_embedding import SharedEmbedding
25
  from .fc import FC_CLASS_REGISTRY as FC_CLASS_REGISTRY
@@ -45,22 +63,216 @@ import logging
45
  log = logging.getLogger(__name__)
46
 
47
 
48
  class MPTPreTrainedModel(PreTrainedModel):
49
  config_class = MPTConfig
50
  base_model_prefix = "model"
51
  _no_split_modules = ["MPTBlock"]
 
52
  supports_gradient_checkpointing = True
53
 
54
- def _set_gradient_checkpointing(self, module: nn.Module, value=False) -> None:
55
- if (
56
- isinstance(module, MPTModel)
57
- or isinstance(module, MultiheadAttention)
58
- or isinstance(module, MultiQueryAttention)
59
- ):
60
- module.gradient_checkpointing = value
61
 
62
 
63
  class MPTModel(MPTPreTrainedModel):
 
64
  def __init__(self, config: MPTConfig):
65
  config._validate_config()
66
  super().__init__(config)
@@ -98,6 +310,18 @@ class MPTModel(MPTPreTrainedModel):
98
  ]
99
  )
100
  self.norm_f = norm_class(config.d_model, device=config.init_device)
101
  if config.init_device != "meta":
102
  log.info(
103
  f'We recommend using config.init_device="meta" with Composer + FSDP for faster initialization.'
@@ -118,18 +342,18 @@ class MPTModel(MPTPreTrainedModel):
118
  if config.no_bias:
119
  for module in self.modules():
120
  if hasattr(module, "bias") and isinstance(module.bias, nn.Parameter):
121
- log.info(f"Removing bias ({module.bias}) from {module}.")
122
  module.register_parameter("bias", None)
123
  if hasattr(module, "use_bias"):
124
- log.info(f"Setting use_bias=False for {module}.")
125
  module.use_bias = False
126
  log.debug(self)
127
  log.debug(f"Using {self.config.init_config['name']} initialization.")
128
 
129
- def get_input_embeddings(self) -> nn.Embedding:
130
  return self.wte
131
 
132
- def set_input_embeddings(self, value: nn.Embedding) -> None:
133
  self.wte = value
134
 
135
  @torch.no_grad()
@@ -167,7 +391,9 @@ class MPTModel(MPTPreTrainedModel):
167
  attn_bias = self._apply_prefix_mask(attn_bias, prefix_mask)
168
  if self.attn_uses_sequence_id and sequence_id is not None:
169
  assert isinstance(attn_bias, torch.Tensor)
170
- attn_bias = self._apply_sequence_id(attn_bias, sequence_id)
171
  if attention_mask is not None:
172
  s_k = attention_mask.shape[-1]
173
  if attn_bias is None:
@@ -184,7 +410,7 @@ class MPTModel(MPTPreTrainedModel):
184
  attn_bias = attn_bias.masked_fill(
185
  ~attention_mask.view(-1, 1, 1, s_k), min_val
186
  )
187
- return (attn_bias, None)
188
 
189
  def _apply_prefix_mask(
190
  self, attn_bias: torch.Tensor, prefix_mask: torch.Tensor
@@ -211,25 +437,9 @@ class MPTModel(MPTPreTrainedModel):
211
  attn_bias = attn_bias.masked_fill(cannot_attend, min_val)
212
  return attn_bias
213
 
214
- def _apply_sequence_id(
215
- self, attn_bias: torch.Tensor, sequence_id: torch.LongTensor
216
- ) -> torch.Tensor:
217
- seq_len = sequence_id.shape[-1]
218
- if seq_len > self.config.max_seq_len:
219
- raise ValueError(
220
- f"sequence_id sequence length cannot exceed max_seq_len={self.config.max_seq_len}"
221
- )
222
- attn_bias = attn_bias[..., :seq_len, :seq_len]
223
- cannot_attend = torch.logical_not(
224
- torch.eq(sequence_id.view(-1, seq_len, 1), sequence_id.view(-1, 1, seq_len))
225
- ).unsqueeze(1)
226
- min_val = torch.finfo(attn_bias.dtype).min
227
- attn_bias = attn_bias.masked_fill(cannot_attend, min_val)
228
- return attn_bias
229
-
230
  def forward(
231
  self,
232
- input_ids: torch.LongTensor,
233
  past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
234
  attention_mask: Optional[torch.ByteTensor] = None,
235
  prefix_mask: Optional[torch.ByteTensor] = None,
@@ -244,9 +454,6 @@ class MPTModel(MPTPreTrainedModel):
244
  return_dict if return_dict is not None else self.config.return_dict
245
  )
246
  use_cache = use_cache if use_cache is not None else self.config.use_cache
247
- if self.gradient_checkpointing and self.training:
248
- if use_cache:
249
- use_cache = False
250
  if attention_mask is not None:
251
  attention_mask = attention_mask.bool()
252
  if prefix_mask is not None:
@@ -272,8 +479,6 @@ class MPTModel(MPTPreTrainedModel):
272
  raise ValueError(
273
  "prefix_mask is a required argument when MPT is configured with prefix_lm=True."
274
  )
275
- if inputs_embeds is not None:
276
- raise NotImplementedError("inputs_embeds is not implemented for MPT.")
277
  if self.training:
278
  if self.attn_uses_sequence_id and sequence_id is None:
279
  raise ValueError(
@@ -285,53 +490,78 @@ class MPTModel(MPTPreTrainedModel):
285
  "MPT received non-None input for `sequence_id` but is configured with attn_uses_sequence_id=False. "
286
  + "This input will be ignored. If you want the model to use `sequence_id`, set attn_uses_sequence_id to True."
287
  )
288
- S = input_ids.size(1)
289
  assert (
290
  S <= self.config.max_seq_len
291
  ), f"Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}"
292
- tok_emb = self.wte(input_ids)
293
- if self.learned_pos_emb:
294
- past_position = 0
295
- if past_key_values is not None:
296
- if len(past_key_values) != self.config.n_layers:
297
- raise ValueError(
298
- f"past_key_values must provide a past_key_value for each attention "
299
- + f"layer in the network (len(past_key_values)={len(past_key_values)!r}; self.config.n_layers={self.config.n_layers!r})."
300
- )
301
- past_position = past_key_values[0][0].size(1)
302
- if self.attn_impl == "torch":
303
- past_position = past_key_values[0][0].size(3)
304
- if S + past_position > self.config.max_seq_len:
305
  raise ValueError(
306
  f"Cannot forward input with past sequence length {past_position} and current sequence length "
307
  + f"{S + 1}, this model only supports total sequence length <= {self.config.max_seq_len}."
308
  )
309
- # print(past_position)
310
- # print(S + past_position)
311
- pos = torch.arange(
312
- past_position,
313
- S + past_position,
314
- dtype=torch.long,
315
- device=input_ids.device,
316
- ).unsqueeze(0)
317
- # print(pos)
318
- if attention_mask is not None:
319
- # print(torch.cumsum((~attention_mask).to(torch.int32), dim=1))
320
- pos = torch.clamp(
321
- pos
322
- - torch.cumsum((~attention_mask).to(torch.int32), dim=1)[
323
- :, past_position:
324
- ],
325
- min=0,
326
- )
327
- # print(pos)
328
- # print(attention_mask)
329
- pos_emb = self.wpe(pos)
330
- # print(pos_emb)
331
- x = tok_emb + pos_emb
332
-
333
- else:
334
- x = tok_emb
335
  if self.embedding_fraction == 1:
336
  x = self.emb_drop(x)
337
  else:
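The learned-position branch removed in the hunk above shifts each position index down by the number of padding tokens seen so far (`pos - cumsum(~attention_mask)`, clamped at 0), so left-padded sequences still start from position 0. A small numeric illustration of that formula (ignoring the `past_key_values` offset; assumes 1 = real token, 0 = padding):

```python
import torch

attention_mask = torch.tensor([[0, 0, 1, 1, 1]])  # two left-padding tokens, three real tokens
pos = torch.arange(0, 5).unsqueeze(0)             # raw positions [0, 1, 2, 3, 4]
pad_so_far = torch.cumsum((~attention_mask.bool()).to(torch.int32), dim=1)
pos = torch.clamp(pos - pad_so_far, min=0)
print(pos)  # tensor([[0, 0, 0, 1, 2]]) -> the first real token gets position 0
```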
@@ -347,12 +577,36 @@ class MPTModel(MPTPreTrainedModel):
347
  prefix_mask=prefix_mask,
348
  sequence_id=sequence_id,
349
  )
350
-
351
  presents = () if use_cache else None
352
  if use_cache and past_key_values is None:
353
  past_key_values = [() for _ in range(self.config.n_layers)]
354
  all_hidden_states = () if output_hidden_states else None
355
  all_self_attns = () if output_attentions else None
356
  for b_idx, block in enumerate(self.blocks):
357
  if output_hidden_states:
358
  assert all_hidden_states is not None
@@ -360,35 +614,31 @@ class MPTModel(MPTPreTrainedModel):
360
  past_key_value = (
361
  past_key_values[b_idx] if past_key_values is not None else None
362
  )
363
-
364
  if self.gradient_checkpointing and self.training:
365
-
366
- def create_custom_forward(module):
367
- def custom_forward(*inputs):
368
- # None for past_key_value
369
- return module(*inputs)
370
-
371
- return custom_forward
372
-
373
- (x, attn_weights, present) = torch.utils.checkpoint.checkpoint(
374
- create_custom_forward(block),
375
  x,
376
  past_key_value,
377
  attn_bias,
 
378
  attention_mask,
379
  self.is_causal,
380
  bool(output_attentions),
 
 
381
  )
382
  else:
383
  (x, attn_weights, present) = block(
384
  x,
385
  past_key_value=past_key_value,
386
  attn_bias=attn_bias,
 
387
  attention_mask=attention_mask,
388
  is_causal=self.is_causal,
389
  output_attentions=bool(output_attentions),
 
 
390
  )
391
-
392
  if presents is not None:
393
  presents += (present,)
394
  if output_attentions:
@@ -415,19 +665,24 @@ class MPTModel(MPTPreTrainedModel):
415
  )
416
 
417
  def fsdp_wrap_fn(self, module: nn.Module) -> bool:
418
- return isinstance(module, MPTBlock)
419
 
420
  def activation_checkpointing_fn(self, module: nn.Module) -> bool:
421
  return isinstance(module, MPTBlock)
422
 
423
 
424
  class MPTForCausalLM(MPTPreTrainedModel):
 
425
  def __init__(self, config: MPTConfig):
426
  super().__init__(config)
427
- if not config.tie_word_embeddings:
428
- raise ValueError("MPTForCausalLM only supports tied word embeddings")
429
  log.info(f"Instantiating an MPTForCausalLM model from {__file__}")
430
  self.transformer: MPTModel = MPTModel(config)
 
431
  for child in self.transformer.children():
432
  if isinstance(child, torch.nn.ModuleList):
433
  continue
@@ -445,19 +700,38 @@ class MPTForCausalLM(MPTPreTrainedModel):
445
  )
446
  self.logit_scale = logit_scale
447
 
448
- def get_input_embeddings(self) -> nn.Embedding:
449
- return self.transformer.wte
450
 
451
  def set_input_embeddings(self, value: Union[SharedEmbedding, nn.Embedding]) -> None:
452
- self.transformer.wte = value
453
 
454
- def get_output_embeddings(self) -> nn.Embedding:
455
- return self.transformer.wte
 
 
456
 
457
  def set_output_embeddings(
458
- self, new_embeddings: Union[SharedEmbedding, nn.Embedding]
459
  ) -> None:
460
- self.transformer.wte = new_embeddings
 
461
 
462
  def set_decoder(self, decoder: MPTModel) -> None:
463
  self.transformer = decoder
@@ -467,7 +741,7 @@ class MPTForCausalLM(MPTPreTrainedModel):
467
 
468
  def forward(
469
  self,
470
- input_ids: torch.LongTensor,
471
  past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
472
  attention_mask: Optional[torch.ByteTensor] = None,
473
  prefix_mask: Optional[torch.ByteTensor] = None,
@@ -483,10 +757,6 @@ class MPTForCausalLM(MPTPreTrainedModel):
483
  return_dict if return_dict is not None else self.config.return_dict
484
  )
485
  use_cache = use_cache if use_cache is not None else self.config.use_cache
486
- if inputs_embeds is not None:
487
- raise NotImplementedError(
488
- "inputs_embeds has to be None (for hf/peft support)."
489
- )
490
  outputs = self.transformer(
491
  input_ids=input_ids,
492
  past_key_values=past_key_values,
@@ -497,10 +767,14 @@ class MPTForCausalLM(MPTPreTrainedModel):
497
  output_attentions=output_attentions,
498
  output_hidden_states=output_hidden_states,
499
  use_cache=use_cache,
 
500
  )
501
- logits = self.transformer.wte(
502
- outputs.last_hidden_state.to(self.transformer.wte.weight.device), True
503
- )
 
 
 
504
  if self.logit_scale is not None:
505
  if self.logit_scale == 0:
506
  warnings.warn(
@@ -532,10 +806,45 @@ class MPTForCausalLM(MPTPreTrainedModel):
532
  )
533
 
534
  def fsdp_wrap_fn(self, module: nn.Module) -> bool:
535
- return isinstance(module, MPTBlock)
536
 
537
  def activation_checkpointing_fn(self, module: nn.Module) -> bool:
538
- return isinstance(module, MPTBlock)
 
539
 
540
  def prepare_inputs_for_generation(
541
  self,
@@ -544,8 +853,6 @@ class MPTForCausalLM(MPTPreTrainedModel):
544
  inputs_embeds: Optional[torch.Tensor] = None,
545
  **kwargs: Any,
546
  ) -> Dict[str, Any]:
547
- if inputs_embeds is not None:
548
- raise NotImplementedError("inputs_embeds is not implemented for MPT yet")
549
  attention_mask = kwargs["attention_mask"].bool()
550
  if attention_mask[:, -1].sum() != attention_mask.shape[0]:
551
  raise NotImplementedError(
@@ -565,14 +872,20 @@ class MPTForCausalLM(MPTPreTrainedModel):
565
  )
566
  else:
567
  prefix_mask = None
568
- return {
569
- "input_ids": input_ids,
570
- "attention_mask": attention_mask,
571
- "prefix_mask": prefix_mask,
572
- "sequence_id": sequence_id,
573
- "past_key_values": past_key_values,
574
- "use_cache": kwargs.get("use_cache", True),
575
- }
 
576
 
577
  @staticmethod
578
  def _reorder_cache(
 
2
 
3
  Inspired by https://github.com/karpathy/minGPT/blob/master/mingpt/model.py
4
  """
5
+
6
+ from __future__ import annotations
7
  import math
8
  import warnings
9
  from typing import Any, Dict, List, Mapping, MutableMapping, Optional, Tuple, Union
10
  import torch
11
  import torch.nn as nn
12
  import torch.nn.functional as F
13
+ from .attention import is_flash_v1_installed, is_flash_v2_installed
14
+
15
+ if is_flash_v2_installed():
16
+ try:
17
+ from flash_attn import bert_padding
18
+ from flash_attn.layers.rotary import RotaryEmbedding as DAILRotaryEmbedding
19
+ except Exception as e:
20
+ raise e
21
+ if is_flash_v1_installed():
22
+ try:
23
+ from flash_attn import bert_padding
24
+ except Exception as e:
25
+ raise e
26
  from transformers import PreTrainedModel, PreTrainedTokenizerBase
27
  from transformers.modeling_outputs import (
28
  BaseModelOutputWithPast,
29
  CausalLMOutputWithPast,
30
  )
31
+ from transformers.models.llama.modeling_llama import (
32
+ LlamaDynamicNTKScalingRotaryEmbedding as HFDynamicNTKScalingRotaryEmbedding,
33
+ )
34
+ from transformers.models.llama.modeling_llama import (
35
+ LlamaLinearScalingRotaryEmbedding as HFLinearScalingRotaryEmbedding,
36
+ )
37
+ from transformers.models.llama.modeling_llama import (
38
+ LlamaRotaryEmbedding as HFRotaryEmbedding,
39
  )
40
+ from .attention import ATTN_CLASS_REGISTRY, attn_bias_shape, build_attn_bias, gen_slopes
41
  from .blocks import MPTBlock
42
  from .custom_embedding import SharedEmbedding
43
  from .fc import FC_CLASS_REGISTRY as FC_CLASS_REGISTRY
 
63
  log = logging.getLogger(__name__)
64
 
65
 
66
+ def gen_rotary_embedding(
67
+ rope_head_dim: int,
68
+ rope_impl: str,
69
+ rope_theta: int,
70
+ rope_dail_config: dict,
71
+ rope_hf_config: dict,
72
+ max_seq_len: int,
73
+ ):
74
+ if rope_impl == "dail":
75
+ return DAILRotaryEmbedding(
76
+ dim=rope_head_dim,
77
+ base=rope_theta,
78
+ interleaved=False,
79
+ scale_base=(
80
+ rope_dail_config["xpos_scale_base"]
81
+ if rope_dail_config["type"] == "xpos"
82
+ else None
83
+ ),
84
+ pos_idx_in_fp32=rope_dail_config["pos_idx_in_fp32"],
85
+ device="cpu",
86
+ )
87
+ elif rope_impl == "hf":
88
+ if rope_hf_config["type"] == "no_scaling":
89
+ return HFRotaryEmbedding(
90
+ rope_head_dim,
91
+ max_position_embeddings=max_seq_len,
92
+ base=rope_theta,
93
+ device="cpu",
94
+ )
95
+ elif rope_hf_config["type"] == "linear":
96
+ return HFLinearScalingRotaryEmbedding(
97
+ rope_head_dim,
98
+ max_position_embeddings=max_seq_len,
99
+ base=rope_theta,
100
+ scaling_factor=rope_hf_config["factor"],
101
+ device="cpu",
102
+ )
103
+ elif rope_hf_config["type"] == "dynamic":
104
+ return HFDynamicNTKScalingRotaryEmbedding(
105
+ rope_head_dim,
106
+ max_position_embeddings=max_seq_len,
107
+ base=rope_theta,
108
+ scaling_factor=rope_hf_config["factor"],
109
+ device="cpu",
110
+ )
111
+ raise ValueError("rope_impl needs to be either dail or hf")
112
+
113
+
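Editor's note, illustration only (not part of this commit): the helper above just dispatches between the DAIL (flash-attn) rotary embedding and the Hugging Face classes imported at the top of this file. A minimal sketch of what the `hf` / `no_scaling` branch resolves to, assuming transformers==4.37.2 as pinned in the usage example; the `head_dim` value is a placeholder, not read from the model config:

```python
import torch
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding

head_dim = 64  # placeholder; a real config would use d_model // n_heads
rope = LlamaRotaryEmbedding(
    head_dim, max_position_embeddings=2048, base=10000, device="cpu"
)
# cos/sin tables for the first 16 positions, as the attention layers consume them
cos, sin = rope(torch.zeros(1, 16, head_dim), seq_len=16)
print(cos.shape, sin.shape)
```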
114
+ def gen_attention_mask_in_length(
115
+ sequence_id: Union[None, torch.Tensor],
116
+ S: int,
117
+ attn_uses_sequence_id: bool,
118
+ attn_impl: str,
119
+ attention_mask: Union[torch.Tensor, None],
120
+ ):
121
+ """Generates the attention mask used for sequence masking in FA v2.
122
+
123
+ Only supports sequence id based sparse attention for no attention masking or attention masking with right padding.
124
+ In case of left padding:
125
+ 1. Training with left padding is not supported in MPT (see https://github.com/mosaicml/llm-foundry/blob/1eecd4cb8e734499f77f6a35f657b8b20c0adfcb/llmfoundry/models/mpt/modeling_mpt.py#L407).
126
+ 2. For generation with left padding, we only have a single sequence id per sample, so we don't need sequence id based sparse attention.
127
+
128
+ Args:
129
+ sequence_id (Union[None, torch.Tensor]): Tensor containing the sequence id for each token. Shape (batch_size, seq_len).
130
+ S (int): Sequence length
131
+ attn_uses_sequence_id (bool): Whether the attention uses sequence id based masking.
132
+ attn_impl (str): Attention implementation. This function only creates attention_mask_in_length for flash attention.
133
+ attention_mask (Union[torch.Tensor, None]): Attention mask tensor of shape (batch_size, seq_len)
134
+
135
+ Returns:
136
+ attention_mask_in_length: (batch, seqlen), int, a nonzero number (e.g., 1, 2, 3, etc.) means length of concatenated sequence in b-th batch, and 0 means none. For example, if batch = 3 and seqlen = 6, the attention_mask_in_length is:
137
+ ```
138
+ [
139
+ [2, 3, 0, 0, 0, 0],
140
+ [3, 2, 0, 0, 0, 0],
141
+ [6, 0, 0, 0, 0, 0]
142
+ ]
143
+ ```
144
+ , which refers to the 3D-attention mask:
145
+ ```
146
+ [
147
+ [
148
+ [1, 0, 0, 0, 0, 0],
149
+ [1, 1, 0, 0, 0, 0],
150
+ [0, 0, 1, 0, 0, 0],
151
+ [0, 0, 1, 1, 0, 0],
152
+ [0, 0, 1, 1, 1, 0],
153
+ [0, 0, 0, 0, 0, 1]
154
+ ],
155
+ [
156
+ [1, 0, 0, 0, 0, 0],
157
+ [1, 1, 0, 0, 0, 0],
158
+ [1, 1, 1, 0, 0, 0],
159
+ [0, 0, 0, 1, 0, 0],
160
+ [0, 0, 0, 1, 1, 0],
161
+ [0, 0, 0, 0, 0, 1]
162
+ ],
163
+ [
164
+ [1, 0, 0, 0, 0, 0],
165
+ [1, 1, 0, 0, 0, 0],
166
+ [1, 1, 1, 0, 0, 0],
167
+ [1, 1, 1, 1, 0, 0],
168
+ [1, 1, 1, 1, 1, 0],
169
+ [1, 1, 1, 1, 1, 1]
170
+ ]
171
+ ]
172
+ ```.
173
+ (The description above is taken verbatim from https://github.com/Dao-AILab/flash-attention/blob/9356a1c0389660d7e231ff3163c1ac17d9e3824a/flash_attn/bert_padding.py#L125 .)
174
+ """
175
+ attention_mask_in_length = None
176
+ if sequence_id is not None and attn_uses_sequence_id and (attn_impl == "flash"):
177
+ if (
178
+ attention_mask is not None
179
+ and attention_mask[:, 0].sum() != attention_mask.shape[0]
180
+ ):
181
+ raise NotImplementedError(
182
+ "Left padding is not supported with flash attention when attn_uses_sequence_id is set to True."
183
+ )
184
+ if S != sequence_id.shape[-1]:
185
+ raise ValueError(
186
+ f"Sequence length ({S}) does not match length of sequences in sequence_id ({sequence_id.shape[-1]})."
187
+ )
188
+ if attention_mask is not None:
189
+ sequence_id = sequence_id.masked_fill(~attention_mask, 0)
190
+ attention_mask_in_length = torch.nn.functional.one_hot(sequence_id)
191
+ if attention_mask is not None:
192
+ attention_mask_in_length = attention_mask_in_length.masked_fill(
193
+ ~attention_mask.unsqueeze(-1), 0
194
+ )
195
+ attention_mask_in_length = attention_mask_in_length.sum(dim=1)
196
+ attention_mask_in_length = torch.nn.functional.pad(
197
+ attention_mask_in_length,
198
+ (0, S - attention_mask_in_length.shape[-1]),
199
+ mode="constant",
200
+ value=0,
201
+ )
202
+ return attention_mask_in_length
203
+
204
+
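Editor's note, illustration only: a small worked example of the `attention_mask_in_length` format described in the docstring above, computed the same way the helper does (one-hot the per-token sequence ids, zero out padded positions, sum, then pad to the sequence length):

```python
import torch

# Sample 0 packs two 3-token sequences; sample 1 packs a 4-token and a 2-token
# sequence. There is no padding, so the attention mask is all ones.
sequence_id = torch.tensor([[0, 0, 0, 1, 1, 1],
                            [0, 0, 0, 0, 1, 1]])
attention_mask = torch.ones_like(sequence_id, dtype=torch.bool)

lengths = torch.nn.functional.one_hot(sequence_id)
lengths = lengths.masked_fill(~attention_mask.unsqueeze(-1), 0).sum(dim=1)
lengths = torch.nn.functional.pad(
    lengths, (0, sequence_id.shape[-1] - lengths.shape[-1]), value=0
)
print(lengths)
# tensor([[3, 3, 0, 0, 0, 0],
#         [4, 2, 0, 0, 0, 0]])
```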
205
+ def gen_flash_attn_padding_info(
206
+ bsz: int,
207
+ S: int,
208
+ past_key_len: int,
209
+ device: torch.device,
210
+ attention_mask_in_length: Optional[torch.Tensor] = None,
211
+ attention_mask: Optional[torch.Tensor] = None,
212
+ ):
213
+ flash_attn_padding_info = {}
214
+ if attention_mask_in_length is None:
215
+ key_padding_mask = attention_mask
216
+ if key_padding_mask is None:
217
+ key_padding_mask = torch.ones(
218
+ (bsz, past_key_len + S), dtype=torch.bool, device=device
219
+ )
220
+ query_padding_mask = key_padding_mask[:, -S:]
221
+ unpadding_function = bert_padding.unpad_input
222
+ else:
223
+ key_padding_mask = attention_mask_in_length
224
+ query_padding_mask = attention_mask_in_length
225
+ unpadding_function = bert_padding.unpad_input_for_concatenated_sequences
226
+ (_, indices_q, cu_seqlens_q, max_seqlen_q) = unpadding_function(
227
+ torch.empty(bsz, S, 1, device=device), query_padding_mask
228
+ )
229
+ (_, indices_k, cu_seqlens_k, max_seqlen_k) = unpadding_function(
230
+ torch.empty(bsz, past_key_len + S, 1, device=device), key_padding_mask
231
+ )
232
+ (_, indices_v, _, _) = unpadding_function(
233
+ torch.empty(bsz, past_key_len + S, 1, device=device), key_padding_mask
234
+ )
235
+ flash_attn_padding_info["indices_q"] = indices_q
236
+ flash_attn_padding_info["indices_k"] = indices_k
237
+ flash_attn_padding_info["indices_v"] = indices_v
238
+ flash_attn_padding_info["cu_seqlens_q"] = cu_seqlens_q
239
+ flash_attn_padding_info["cu_seqlens_k"] = cu_seqlens_k
240
+ flash_attn_padding_info["max_seqlen_q"] = max_seqlen_q
241
+ flash_attn_padding_info["max_seqlen_k"] = max_seqlen_k
242
+ return flash_attn_padding_info
243
+
244
+
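Editor's note, illustration only: the real helper goes through `flash_attn.bert_padding`, but the metadata it collects (token indices, cumulative sequence lengths, and the longest sequence) can be pictured with plain PyTorch for a right-padded batch:

```python
import torch

attention_mask = torch.tensor(
    [[1, 1, 1, 1, 0, 0],    # sample 0: 4 real tokens, 2 padding
     [1, 1, 1, 1, 1, 1]],   # sample 1: 6 real tokens
    dtype=torch.bool,
)

seqlens = attention_mask.sum(dim=1, dtype=torch.int32)
cu_seqlens = torch.nn.functional.pad(
    torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0)
)
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
max_seqlen = int(seqlens.max())

print(cu_seqlens)  # tensor([ 0,  4, 10], dtype=torch.int32)
print(max_seqlen)  # 6
```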
245
+ def apply_sequence_id(
246
+ attn_bias: torch.Tensor, sequence_id: torch.LongTensor, max_seq_len: int
247
+ ) -> torch.Tensor:
248
+ seq_len = sequence_id.shape[-1]
249
+ if seq_len > max_seq_len:
250
+ raise ValueError(
251
+ f"sequence_id sequence length cannot exceed max_seq_len={max_seq_len}"
252
+ )
253
+ attn_bias = attn_bias[..., :seq_len, :seq_len]
254
+ cannot_attend = torch.logical_not(
255
+ torch.eq(sequence_id.view(-1, seq_len, 1), sequence_id.view(-1, 1, seq_len))
256
+ ).unsqueeze(1)
257
+ min_val = torch.finfo(attn_bias.dtype).min
258
+ attn_bias = attn_bias.masked_fill(cannot_attend, min_val)
259
+ return attn_bias
260
+
261
+
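Editor's note, illustration only: a standalone restatement of the masking step above for a single sample packing two 2-token sequences. Positions whose sequence ids differ receive the dtype's most negative value, so they get effectively zero attention weight after the softmax.

```python
import torch

sequence_id = torch.tensor([[0, 0, 1, 1]])        # two packed sequences of length 2
attn_bias = torch.zeros(1, 1, 4, 4)               # (batch, heads, query, key)
cannot_attend = torch.logical_not(
    torch.eq(sequence_id.view(-1, 4, 1), sequence_id.view(-1, 1, 4))
).unsqueeze(1)
attn_bias = attn_bias.masked_fill(cannot_attend, torch.finfo(attn_bias.dtype).min)

print((attn_bias[0, 0] == 0).int())   # 1 where attention is still allowed
# tensor([[1, 1, 0, 0],
#         [1, 1, 0, 0],
#         [0, 0, 1, 1],
#         [0, 0, 1, 1]], dtype=torch.int32)
```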
262
  class MPTPreTrainedModel(PreTrainedModel):
263
  config_class = MPTConfig
264
  base_model_prefix = "model"
265
  _no_split_modules = ["MPTBlock"]
266
+ _supports_flash_attn_2 = True
267
  supports_gradient_checkpointing = True
268
 
269
+
270
+ def _fsdp_wrap_fn(self: Union[MPTModel, MPTForCausalLM], module: nn.Module) -> bool:
271
+ return isinstance(module, MPTBlock)
 
 
 
 
272
 
273
 
274
  class MPTModel(MPTPreTrainedModel):
275
+
276
  def __init__(self, config: MPTConfig):
277
  config._validate_config()
278
  super().__init__(config)
 
310
  ]
311
  )
312
  self.norm_f = norm_class(config.d_model, device=config.init_device)
313
+ self.rope = config.attn_config["rope"]
314
+ self.rope_impl = None
315
+ if self.rope:
316
+ self.rope_impl = config.attn_config["rope_impl"]
317
+ self.rotary_embedding = gen_rotary_embedding(
318
+ rope_head_dim=config.d_model // config.n_heads,
319
+ rope_impl=self.rope_impl,
320
+ rope_theta=config.attn_config["rope_theta"],
321
+ rope_dail_config=config.attn_config["rope_dail_config"],
322
+ rope_hf_config=config.attn_config["rope_hf_config"],
323
+ max_seq_len=self.config.max_seq_len,
324
+ )
325
  if config.init_device != "meta":
326
  log.info(
327
  f'We recommend using config.init_device="meta" with Composer + FSDP for faster initialization.'
 
342
  if config.no_bias:
343
  for module in self.modules():
344
  if hasattr(module, "bias") and isinstance(module.bias, nn.Parameter):
345
+ log.info(f"Removing bias from module={module!r}.")
346
  module.register_parameter("bias", None)
347
  if hasattr(module, "use_bias"):
348
+ log.info(f"Setting use_bias=False for module={module!r}.")
349
  module.use_bias = False
350
  log.debug(self)
351
  log.debug(f"Using {self.config.init_config['name']} initialization.")
352
 
353
+ def get_input_embeddings(self) -> Union[SharedEmbedding, nn.Embedding]:
354
  return self.wte
355
 
356
+ def set_input_embeddings(self, value: Union[SharedEmbedding, nn.Embedding]) -> None:
357
  self.wte = value
358
 
359
  @torch.no_grad()
 
391
  attn_bias = self._apply_prefix_mask(attn_bias, prefix_mask)
392
  if self.attn_uses_sequence_id and sequence_id is not None:
393
  assert isinstance(attn_bias, torch.Tensor)
394
+ attn_bias = apply_sequence_id(
395
+ attn_bias, sequence_id, self.config.max_seq_len
396
+ )
397
  if attention_mask is not None:
398
  s_k = attention_mask.shape[-1]
399
  if attn_bias is None:
 
410
  attn_bias = attn_bias.masked_fill(
411
  ~attention_mask.view(-1, 1, 1, s_k), min_val
412
  )
413
+ return (attn_bias, attention_mask)
414
 
415
  def _apply_prefix_mask(
416
  self, attn_bias: torch.Tensor, prefix_mask: torch.Tensor
 
437
  attn_bias = attn_bias.masked_fill(cannot_attend, min_val)
438
  return attn_bias
439
 
 
440
  def forward(
441
  self,
442
+ input_ids: Optional[torch.LongTensor] = None,
443
  past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
444
  attention_mask: Optional[torch.ByteTensor] = None,
445
  prefix_mask: Optional[torch.ByteTensor] = None,
 
454
  return_dict if return_dict is not None else self.config.return_dict
455
  )
456
  use_cache = use_cache if use_cache is not None else self.config.use_cache
 
 
 
457
  if attention_mask is not None:
458
  attention_mask = attention_mask.bool()
459
  if prefix_mask is not None:
 
479
  raise ValueError(
480
  "prefix_mask is a required argument when MPT is configured with prefix_lm=True."
481
  )
 
 
482
  if self.training:
483
  if self.attn_uses_sequence_id and sequence_id is None:
484
  raise ValueError(
 
490
  "MPT received non-None input for `sequence_id` but is configured with attn_uses_sequence_id=False. "
491
  + "This input will be ignored. If you want the model to use `sequence_id`, set attn_uses_sequence_id to True."
492
  )
493
+
494
+ if self.gradient_checkpointing and self.training and use_cache:
495
+ warnings.warn(
496
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
497
+ )
498
+ use_cache = False
499
+
500
+ if input_ids is not None and inputs_embeds is not None:
501
+ raise ValueError("You cannot specify both input_ids and inputs_embeds.")
502
+ elif input_ids is not None:
503
+ bsz = input_ids.size(0)
504
+ S = input_ids.size(1)
505
+ x = self.wte(input_ids)
506
+ input_device = input_ids.device
507
+ elif inputs_embeds is not None:
508
+ bsz = inputs_embeds.size(0)
509
+ S = inputs_embeds.size(1)
510
+ x = inputs_embeds
511
+ input_device = inputs_embeds.device
512
+ else:
513
+ raise ValueError("You must specify input_ids or inputs_embeds")
514
  assert (
515
  S <= self.config.max_seq_len
516
  ), f"Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}"
517
+ rotary_emb_w_meta_info = None
518
+ past_position = 0
519
+ if past_key_values is not None:
520
+ if len(past_key_values) != self.config.n_layers:
521
+ raise ValueError(
522
+ f"past_key_values must provide a past_key_value for each attention "
523
+ + f"layer in the network (len(past_key_values)={len(past_key_values)!r}; self.config.n_layers={self.config.n_layers!r})."
524
+ )
525
+ past_position = past_key_values[0][0].size(1)
526
+ if self.attn_impl == "torch":
527
+ past_position = past_key_values[0][0].size(3)
528
+ if self.learned_pos_emb or self.rope:
529
+ if self.learned_pos_emb and S + past_position > self.config.max_seq_len:
530
  raise ValueError(
531
  f"Cannot forward input with past sequence length {past_position} and current sequence length "
532
  + f"{S + 1}, this model only supports total sequence length <= {self.config.max_seq_len}."
533
  )
534
+ if self.learned_pos_emb or (self.rope and self.rope_impl == "hf"):
535
+ pos = torch.arange(
536
+ past_position,
537
+ S + past_position,
538
+ dtype=torch.long,
539
+ device=input_device,
540
+ ).unsqueeze(0)
541
+ if attention_mask is not None:
542
+ pos = torch.clamp(
543
+ pos
544
+ - torch.cumsum((~attention_mask).to(torch.int32), dim=1)[
545
+ :, past_position:
546
+ ],
547
+ min=0,
548
+ )
549
+ if self.learned_pos_emb:
550
+ x = x + self.wpe(pos)
551
+ elif self.rope and self.rope_impl == "hf":
552
+ rotary_emb_w_meta_info = {
553
+ "impl": self.rope_impl,
554
+ "rotary_emb": self.rotary_embedding,
555
+ "offset_info": pos,
556
+ "seq_len": S + past_position,
557
+ }
558
+ elif self.rope and self.rope_impl == "dail":
559
+ rotary_emb_w_meta_info = {
560
+ "impl": self.rope_impl,
561
+ "rotary_emb": self.rotary_embedding,
562
+ "offset_info": past_position,
563
+ "seq_len": S + past_position,
564
+ }
565
  if self.embedding_fraction == 1:
566
  x = self.emb_drop(x)
567
  else:
 
577
  prefix_mask=prefix_mask,
578
  sequence_id=sequence_id,
579
  )
580
+ attention_mask_in_length = gen_attention_mask_in_length(
581
+ sequence_id=sequence_id,
582
+ S=S,
583
+ attn_uses_sequence_id=self.attn_uses_sequence_id,
584
+ attn_impl=self.attn_impl,
585
+ attention_mask=attention_mask,
586
+ )
587
+ alibi_slopes = None
588
+ if self.alibi and self.attn_impl == "flash":
589
+ alibi_slopes = gen_slopes(
590
+ n_heads=self.config.n_heads,
591
+ alibi_bias_max=self.alibi_bias_max,
592
+ device=x.device,
593
+ return_1d=True,
594
+ )
595
  presents = () if use_cache else None
596
  if use_cache and past_key_values is None:
597
  past_key_values = [() for _ in range(self.config.n_layers)]
598
  all_hidden_states = () if output_hidden_states else None
599
  all_self_attns = () if output_attentions else None
600
+ flash_attn_padding_info = {}
601
+ if self.attn_impl == "flash":
602
+ flash_attn_padding_info = gen_flash_attn_padding_info(
603
+ bsz,
604
+ S,
605
+ past_position,
606
+ x.device,
607
+ attention_mask_in_length,
608
+ attention_mask,
609
+ )
610
  for b_idx, block in enumerate(self.blocks):
611
  if output_hidden_states:
612
  assert all_hidden_states is not None
 
614
  past_key_value = (
615
  past_key_values[b_idx] if past_key_values is not None else None
616
  )
 
617
  if self.gradient_checkpointing and self.training:
618
+ (x, attn_weights, present) = self._gradient_checkpointing_func(
619
+ block.__call__,
 
 
 
 
 
 
 
 
620
  x,
621
  past_key_value,
622
  attn_bias,
623
+ rotary_emb_w_meta_info,
624
  attention_mask,
625
  self.is_causal,
626
  bool(output_attentions),
627
+ alibi_slopes,
628
+ flash_attn_padding_info,
629
  )
630
  else:
631
  (x, attn_weights, present) = block(
632
  x,
633
  past_key_value=past_key_value,
634
  attn_bias=attn_bias,
635
+ rotary_emb_w_meta_info=rotary_emb_w_meta_info,
636
  attention_mask=attention_mask,
637
  is_causal=self.is_causal,
638
  output_attentions=bool(output_attentions),
639
+ alibi_slopes=alibi_slopes,
640
+ flash_attn_padding_info=flash_attn_padding_info,
641
  )
 
642
  if presents is not None:
643
  presents += (present,)
644
  if output_attentions:
 
665
  )
666
 
667
  def fsdp_wrap_fn(self, module: nn.Module) -> bool:
668
+ return _fsdp_wrap_fn(self, module)
669
 
670
  def activation_checkpointing_fn(self, module: nn.Module) -> bool:
671
  return isinstance(module, MPTBlock)
672
 
673
 
674
  class MPTForCausalLM(MPTPreTrainedModel):
675
+
676
  def __init__(self, config: MPTConfig):
677
  super().__init__(config)
 
 
678
  log.info(f"Instantiating an MPTForCausalLM model from {__file__}")
679
  self.transformer: MPTModel = MPTModel(config)
680
+ self.lm_head = None
681
+ if not config.tie_word_embeddings:
682
+ self.lm_head = nn.Linear(
683
+ config.d_model, config.vocab_size, bias=False, device=config.init_device
684
+ )
685
+ self.lm_head._fsdp_wrap = True
686
  for child in self.transformer.children():
687
  if isinstance(child, torch.nn.ModuleList):
688
  continue
 
700
  )
701
  self.logit_scale = logit_scale
702
 
703
+ def get_input_embeddings(self) -> Union[SharedEmbedding, nn.Embedding]:
704
+ return self.transformer.get_input_embeddings()
705
 
706
  def set_input_embeddings(self, value: Union[SharedEmbedding, nn.Embedding]) -> None:
707
+ self.transformer.set_input_embeddings(value)
708
 
709
+ def get_output_embeddings(self) -> Union[SharedEmbedding, nn.Embedding, nn.Linear]:
710
+ if self.lm_head is not None:
711
+ return self.lm_head
712
+ return self.transformer.get_input_embeddings()
713
 
714
  def set_output_embeddings(
715
+ self, new_embeddings: Union[SharedEmbedding, nn.Embedding, nn.Linear]
716
  ) -> None:
717
+ if self.lm_head is not None:
718
+ self.lm_head = new_embeddings
719
+ else:
720
+ if not isinstance(new_embeddings, (SharedEmbedding, nn.Embedding)):
721
+ raise ValueError(
722
+ "new_embeddings must be an instance of SharedEmbedding "
723
+ + f"or nn.Embedding, but got {type(new_embeddings)}."
724
+ )
725
+ warnings.warn(
726
+ "Using `set_output_embeddings` to set the embedding layer of "
727
+ + "MPTForCausalLM with tied weights. Given weights are tied, "
728
+ + "using `set_input_embeddings` is recommended over using "
729
+ + "`set_output_embeddings`."
730
+ )
731
+ self.transformer.set_input_embeddings(new_embeddings)
732
+
733
+ def tie_weights(self) -> None:
734
+ self.lm_head = None
735
 
736
  def set_decoder(self, decoder: MPTModel) -> None:
737
  self.transformer = decoder
 
741
 
742
  def forward(
743
  self,
744
+ input_ids: Optional[torch.LongTensor] = None,
745
  past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
746
  attention_mask: Optional[torch.ByteTensor] = None,
747
  prefix_mask: Optional[torch.ByteTensor] = None,
 
757
  return_dict if return_dict is not None else self.config.return_dict
758
  )
759
  use_cache = use_cache if use_cache is not None else self.config.use_cache
 
 
 
 
760
  outputs = self.transformer(
761
  input_ids=input_ids,
762
  past_key_values=past_key_values,
 
767
  output_attentions=output_attentions,
768
  output_hidden_states=output_hidden_states,
769
  use_cache=use_cache,
770
+ inputs_embeds=inputs_embeds,
771
  )
772
+ if self.lm_head is not None:
773
+ logits = self.lm_head(outputs.last_hidden_state)
774
+ else:
775
+ out = outputs.last_hidden_state
776
+ out = out.to(self.transformer.wte.weight.device)
777
+ logits = self.transformer.wte(out, True)
778
  if self.logit_scale is not None:
779
  if self.logit_scale == 0:
780
  warnings.warn(
 
806
  )
807
 
808
  def fsdp_wrap_fn(self, module: nn.Module) -> bool:
809
+ return _fsdp_wrap_fn(self, module)
810
 
811
  def activation_checkpointing_fn(self, module: nn.Module) -> bool:
812
+ act_ckpt_list = getattr(
813
+ self.config, "activation_checkpointing_target", None
814
+ ) or ["MPTBlock"]
815
+ if isinstance(act_ckpt_list, str):
816
+ act_ckpt_list = [act_ckpt_list]
817
+ elif not isinstance(act_ckpt_list, list):
818
+ raise ValueError(
819
+ f"activation_checkpointing_target must be either a single string or a list, but got {type(act_ckpt_list)}"
820
+ )
821
+ if "MPTBlock" in act_ckpt_list or "mptblock" in act_ckpt_list:
822
+ if len(act_ckpt_list) > 1:
823
+ log.info(
824
+ "Activation checkpointing MPTBlock only (ignoring other sub-block modules specified in activation_checkpointing_target)."
825
+ )
826
+ return isinstance(module, MPTBlock)
827
+ mod_types = ()
828
+ for mod_name in act_ckpt_list:
829
+ if mod_name.lower() == "mptblock":
830
+ mod_types += (MPTBlock,)
831
+ elif mod_name in ATTN_CLASS_REGISTRY:
832
+ mod_types += (ATTN_CLASS_REGISTRY[mod_name],)
833
+ elif mod_name in FFN_CLASS_REGISTRY:
834
+ mod_types += (FFN_CLASS_REGISTRY[mod_name],)
835
+ elif mod_name in NORM_CLASS_REGISTRY:
836
+ mod_types += (NORM_CLASS_REGISTRY[mod_name],)
837
+ else:
838
+ msg = ", ".join(
839
+ list(ATTN_CLASS_REGISTRY.keys())
840
+ + list(FFN_CLASS_REGISTRY.keys())
841
+ + list(NORM_CLASS_REGISTRY.keys())
842
+ + ["MPTBlock"]
843
+ )
844
+ raise ValueError(
845
+ f"{mod_name} (specified in activation_checkpointing_target) is not a recognized option out of available options {msg}."
846
+ )
847
+ return isinstance(module, mod_types)
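Editor's note, illustration only: `activation_checkpointing_target` is read from the model config and may be a single name or a list of registry names; when it is absent, the old behaviour (checkpointing whole `MPTBlock`s) is kept. A hedged sketch, assuming `model` is an instance of this `MPTForCausalLM` (for example, loaded with `trust_remote_code=True`):

```python
# Restrict activation checkpointing to whole transformer blocks (the default),
# then list which sub-modules would be wrapped.
model.config.activation_checkpointing_target = "MPTBlock"
for name, module in model.transformer.named_modules():
    if model.activation_checkpointing_fn(module):
        print("checkpoint:", name)
```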
848
 
849
  def prepare_inputs_for_generation(
850
  self,
 
853
  inputs_embeds: Optional[torch.Tensor] = None,
854
  **kwargs: Any,
855
  ) -> Dict[str, Any]:
 
 
856
  attention_mask = kwargs["attention_mask"].bool()
857
  if attention_mask[:, -1].sum() != attention_mask.shape[0]:
858
  raise NotImplementedError(
 
872
  )
873
  else:
874
  prefix_mask = None
875
+ if inputs_embeds is not None and past_key_values is None:
876
+ model_inputs = {"inputs_embeds": inputs_embeds}
877
+ else:
878
+ model_inputs = {"input_ids": input_ids}
879
+ model_inputs.update(
880
+ {
881
+ "attention_mask": attention_mask,
882
+ "prefix_mask": prefix_mask,
883
+ "sequence_id": sequence_id,
884
+ "past_key_values": past_key_values,
885
+ "use_cache": kwargs.get("use_cache", True),
886
+ }
887
+ )
888
+ return model_inputs
889
 
890
  @staticmethod
891
  def _reorder_cache(
norm.py CHANGED
@@ -1,57 +1,122 @@
1
  from typing import Dict, List, Optional, Type, Union
2
  import torch
3
 
 
4
  def _cast_if_autocast_enabled(tensor: torch.Tensor) -> torch.Tensor:
5
  if torch.is_autocast_enabled():
6
- if tensor.device.type == 'cuda':
7
  dtype = torch.get_autocast_gpu_dtype()
8
- elif tensor.device.type == 'cpu':
9
  dtype = torch.get_autocast_cpu_dtype()
10
  else:
11
  raise NotImplementedError()
12
  return tensor.to(dtype=dtype)
13
  return tensor
14
 
 
15
  class LPLayerNorm(torch.nn.LayerNorm):
16
 
17
- def __init__(self, normalized_shape: Union[int, List[int], torch.Size], eps: float=1e-05, elementwise_affine: bool=True, device: Optional[torch.device]=None, dtype: Optional[torch.dtype]=None):
18
- super().__init__(normalized_shape=normalized_shape, eps=eps, elementwise_affine=elementwise_affine, device=device, dtype=dtype)
 
19
 
20
  def forward(self, x: torch.Tensor) -> torch.Tensor:
21
  module_device = x.device
22
  downcast_x = _cast_if_autocast_enabled(x)
23
- downcast_weight = _cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight
24
- downcast_bias = _cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias
 
25
  with torch.autocast(enabled=False, device_type=module_device.type):
26
- return torch.nn.functional.layer_norm(downcast_x, self.normalized_shape, downcast_weight, downcast_bias, self.eps)
 
27
 
28
- def rms_norm(x: torch.Tensor, weight: Optional[torch.Tensor]=None, eps: float=1e-05) -> torch.Tensor:
 
 
 
29
  output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
30
  if weight is not None:
31
  return output * weight
32
  return output
33
 
 
34
  class RMSNorm(torch.nn.Module):
35
 
36
- def __init__(self, normalized_shape: Union[int, List[int], torch.Size], eps: float=1e-05, weight: bool=True, dtype: Optional[torch.dtype]=None, device: Optional[torch.device]=None):
 
37
  super().__init__()
38
  self.eps = eps
39
  if weight:
40
- self.weight = torch.nn.Parameter(torch.ones(normalized_shape, dtype=dtype, device=device))
 
 
41
  else:
42
- self.register_parameter('weight', None)
43
 
44
  def forward(self, x: torch.Tensor) -> torch.Tensor:
45
  return rms_norm(x.float(), self.weight, self.eps).to(dtype=x.dtype)
46
 
 
47
  class LPRMSNorm(RMSNorm):
48
 
49
- def __init__(self, normalized_shape: Union[int, List[int], torch.Size], eps: float=1e-05, weight: bool=True, dtype: Optional[torch.dtype]=None, device: Optional[torch.device]=None):
50
- super().__init__(normalized_shape=normalized_shape, eps=eps, weight=weight, dtype=dtype, device=device)
 
 
 
51
 
52
  def forward(self, x: torch.Tensor) -> torch.Tensor:
53
  downcast_x = _cast_if_autocast_enabled(x)
54
- downcast_weight = _cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight
 
 
 
 
55
  with torch.autocast(enabled=False, device_type=x.device.type):
56
  return rms_norm(downcast_x, downcast_weight, self.eps).to(dtype=x.dtype)
57
- NORM_CLASS_REGISTRY: Dict[str, Type[torch.nn.Module]] = {'layernorm': torch.nn.LayerNorm, 'low_precision_layernorm': LPLayerNorm, 'rmsnorm': RMSNorm, 'low_precision_rmsnorm': LPRMSNorm}
 
 
1
  from typing import Dict, List, Optional, Type, Union
2
  import torch
3
 
4
+
5
  def _cast_if_autocast_enabled(tensor: torch.Tensor) -> torch.Tensor:
6
  if torch.is_autocast_enabled():
7
+ if tensor.device.type == "cuda":
8
  dtype = torch.get_autocast_gpu_dtype()
9
+ elif tensor.device.type == "cpu":
10
  dtype = torch.get_autocast_cpu_dtype()
11
  else:
12
  raise NotImplementedError()
13
  return tensor.to(dtype=dtype)
14
  return tensor
15
 
16
+
17
  class LPLayerNorm(torch.nn.LayerNorm):
18
 
19
+ def __init__(
20
+ self,
21
+ normalized_shape: Union[int, List[int], torch.Size],
22
+ eps: float = 1e-05,
23
+ elementwise_affine: bool = True,
24
+ device: Optional[torch.device] = None,
25
+ dtype: Optional[torch.dtype] = None,
26
+ ):
27
+ super().__init__(
28
+ normalized_shape=normalized_shape,
29
+ eps=eps,
30
+ elementwise_affine=elementwise_affine,
31
+ device=device,
32
+ dtype=dtype,
33
+ )
34
 
35
  def forward(self, x: torch.Tensor) -> torch.Tensor:
36
  module_device = x.device
37
  downcast_x = _cast_if_autocast_enabled(x)
38
+ downcast_weight = (
39
+ _cast_if_autocast_enabled(self.weight)
40
+ if self.weight is not None
41
+ else self.weight
42
+ )
43
+ downcast_bias = (
44
+ _cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias
45
+ )
46
  with torch.autocast(enabled=False, device_type=module_device.type):
47
+ return torch.nn.functional.layer_norm(
48
+ downcast_x,
49
+ self.normalized_shape,
50
+ downcast_weight,
51
+ downcast_bias,
52
+ self.eps,
53
+ )
54
 
55
+
56
+ def rms_norm(
57
+ x: torch.Tensor, weight: Optional[torch.Tensor] = None, eps: float = 1e-05
58
+ ) -> torch.Tensor:
59
  output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
60
  if weight is not None:
61
  return output * weight
62
  return output
63
 
64
+
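Editor's note, illustration only: a quick numeric check of the formula above; each position is scaled by the reciprocal root mean square of its features.

```python
import torch

x = torch.tensor([[3.0, 4.0]])
print(x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-05))
# tensor([[0.8485, 1.1314]])  i.e. x / sqrt((9 + 16) / 2) ≈ x / 3.5355
```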
65
  class RMSNorm(torch.nn.Module):
66
 
67
+ def __init__(
68
+ self,
69
+ normalized_shape: Union[int, List[int], torch.Size],
70
+ eps: float = 1e-05,
71
+ weight: bool = True,
72
+ dtype: Optional[torch.dtype] = None,
73
+ device: Optional[torch.device] = None,
74
+ ):
75
  super().__init__()
76
  self.eps = eps
77
  if weight:
78
+ self.weight = torch.nn.Parameter(
79
+ torch.ones(normalized_shape, dtype=dtype, device=device)
80
+ )
81
  else:
82
+ self.register_parameter("weight", None)
83
 
84
  def forward(self, x: torch.Tensor) -> torch.Tensor:
85
  return rms_norm(x.float(), self.weight, self.eps).to(dtype=x.dtype)
86
 
87
+
88
  class LPRMSNorm(RMSNorm):
89
 
90
+ def __init__(
91
+ self,
92
+ normalized_shape: Union[int, List[int], torch.Size],
93
+ eps: float = 1e-05,
94
+ weight: bool = True,
95
+ dtype: Optional[torch.dtype] = None,
96
+ device: Optional[torch.device] = None,
97
+ ):
98
+ super().__init__(
99
+ normalized_shape=normalized_shape,
100
+ eps=eps,
101
+ weight=weight,
102
+ dtype=dtype,
103
+ device=device,
104
+ )
105
 
106
  def forward(self, x: torch.Tensor) -> torch.Tensor:
107
  downcast_x = _cast_if_autocast_enabled(x)
108
+ downcast_weight = (
109
+ _cast_if_autocast_enabled(self.weight)
110
+ if self.weight is not None
111
+ else self.weight
112
+ )
113
  with torch.autocast(enabled=False, device_type=x.device.type):
114
  return rms_norm(downcast_x, downcast_weight, self.eps).to(dtype=x.dtype)
115
+
116
+
117
+ NORM_CLASS_REGISTRY: Dict[str, Type[torch.nn.Module]] = {
118
+ "layernorm": torch.nn.LayerNorm,
119
+ "low_precision_layernorm": LPLayerNorm,
120
+ "rmsnorm": RMSNorm,
121
+ "low_precision_rmsnorm": LPRMSNorm,
122
+ }
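Editor's note, illustration only: this registry is how the model maps `config.norm_type` to a class. A small usage sketch; the `from norm import ...` path is an assumption (the file has no relative imports, so it can be imported on its own when it is on the path):

```python
import torch
from norm import NORM_CLASS_REGISTRY  # assumed import path for this file

norm_class = NORM_CLASS_REGISTRY["low_precision_rmsnorm"]  # -> LPRMSNorm
norm = norm_class(normalized_shape=8)
x = torch.randn(2, 4, 8)
print(norm(x).shape)  # torch.Size([2, 4, 8])
```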
param_init_fns.py CHANGED
@@ -7,69 +7,90 @@ import torch
7
  from torch import nn
8
  from .fc import FC_CLASS_REGISTRY
9
  from .norm import NORM_CLASS_REGISTRY
 
10
  try:
11
  import transformer_engine.pytorch as te
12
  except:
13
  te = None
14
 
 
15
  def torch_default_param_init_fn_(module: nn.Module, **kwargs: Any) -> None:
16
  del kwargs
17
- if hasattr(module, 'reset_parameters') and isinstance(module.reset_parameters, Callable):
 
 
18
  module.reset_parameters()
19
 
 
20
  def fused_init_helper_(module: nn.Module, init_fn_: Callable) -> None:
21
- _fused = getattr(module, '_fused', None)
22
  if _fused is None:
23
- raise RuntimeError(f'Internal logic error')
24
  assert isinstance(module.weight, torch.Tensor)
25
  (dim, splits) = _fused
26
  splits = (0, *splits, module.weight.size(dim))
27
- for (s, e) in zip(splits[:-1], splits[1:]):
28
  slice_indices = [slice(None)] * module.weight.ndim
29
  slice_indices[dim] = slice(s, e)
30
  init_fn_(module.weight[slice_indices])
31
 
32
- def generic_param_init_fn_(module: nn.Module, init_fn_: Callable, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, **kwargs: Any) -> None:
 
 
33
  del kwargs
34
  init_div_is_residual = init_div_is_residual
35
  if init_div_is_residual is False:
36
  div_is_residual = 1.0
37
  elif init_div_is_residual is True:
38
  div_is_residual = math.sqrt(2 * n_layers)
39
- elif isinstance(init_div_is_residual, float) or isinstance(init_div_is_residual, int):
 
 
40
  div_is_residual = init_div_is_residual
41
  elif init_div_is_residual.isnumeric():
42
  div_is_residual = float(init_div_is_residual)
43
  else:
44
  div_is_residual = 1.0
45
- raise ValueError(f'Expected init_div_is_residual to be boolean or numeric, got {init_div_is_residual}')
 
 
46
  if isinstance(module, tuple(set(FC_CLASS_REGISTRY.values()))):
47
- if hasattr(module, '_fused'):
48
  fused_init_helper_(module, init_fn_)
49
  else:
50
  init_fn_(module.weight)
51
  if module.bias is not None:
52
  assert isinstance(module.bias, torch.Tensor)
53
  torch.nn.init.zeros_(module.bias)
54
- if init_div_is_residual is not False and getattr(module, '_is_residual', False):
55
  with torch.no_grad():
56
  module.weight.div_(div_is_residual)
57
  elif isinstance(module, nn.Embedding):
58
  if emb_init_std is not None:
59
  std = emb_init_std
60
  if std == 0:
61
- warnings.warn(f'Embedding layer initialized to 0.')
62
  emb_init_fn_ = partial(torch.nn.init.normal_, mean=0.0, std=std)
63
  elif emb_init_uniform_lim is not None:
64
  lim = emb_init_uniform_lim
65
  if isinstance(lim, Sequence):
66
  if len(lim) > 2:
67
- raise ValueError(f'Uniform init requires a min and a max limit. User input: {lim}.')
 
 
68
  if lim[0] == lim[1]:
69
- warnings.warn(f'Embedding layer initialized to {lim[0]}.')
70
  else:
71
  if lim == 0:
72
- warnings.warn(f'Embedding layer initialized to 0.')
73
  lim = [-lim, lim]
74
  (a, b) = lim
75
  emb_init_fn_ = partial(torch.nn.init.uniform_, a=a, b=b)
@@ -77,21 +98,29 @@ def generic_param_init_fn_(module: nn.Module, init_fn_: Callable, n_layers: int,
77
  emb_init_fn_ = init_fn_
78
  emb_init_fn_(module.weight)
79
  elif isinstance(module, tuple(set(NORM_CLASS_REGISTRY.values()))):
80
- if hasattr(module, 'weight') and isinstance(module.weight, torch.Tensor):
81
  torch.nn.init.ones_(module.weight)
82
- if hasattr(module, 'bias') and isinstance(module.bias, torch.Tensor):
83
  torch.nn.init.zeros_(module.bias)
84
  elif isinstance(module, nn.MultiheadAttention):
85
  if module._qkv_same_embed_dim:
86
  assert module.in_proj_weight is not None
87
- assert module.q_proj_weight is None and module.k_proj_weight is None and (module.v_proj_weight is None)
 
 
 
 
88
  assert d_model is not None
89
  _d = d_model
90
  splits = (0, _d, 2 * _d, 3 * _d)
91
- for (s, e) in zip(splits[:-1], splits[1:]):
92
  init_fn_(module.in_proj_weight[s:e])
93
  else:
94
- assert module.q_proj_weight is not None and module.k_proj_weight is not None and (module.v_proj_weight is not None)
 
 
 
 
95
  assert module.in_proj_weight is None
96
  init_fn_(module.q_proj_weight)
97
  init_fn_(module.k_proj_weight)
@@ -103,7 +132,9 @@ def generic_param_init_fn_(module: nn.Module, init_fn_: Callable, n_layers: int,
103
  if module.bias_v is not None:
104
  torch.nn.init.zeros_(module.bias_v)
105
  init_fn_(module.out_proj.weight)
106
- if init_div_is_residual is not False and getattr(module.out_proj, '_is_residual', False):
 
 
107
  with torch.no_grad():
108
  module.out_proj.weight.div_(div_is_residual)
109
  if module.out_proj.bias is not None:
@@ -125,28 +156,94 @@ def generic_param_init_fn_(module: nn.Module, init_fn_: Callable, n_layers: int,
125
  module.fc2_weight.div_(div_is_residual)
126
  else:
127
  for _ in module.parameters(recurse=False):
128
- raise NotImplementedError(f'{module.__class__.__name__} parameters are not initialized by param_init_fn.')
 
 
129
 
130
- def _normal_init_(std: float, mean: float=0.0) -> Callable:
 
131
  return partial(torch.nn.init.normal_, mean=mean, std=std)
132
 
133
- def _normal_param_init_fn_(module: nn.Module, std: float, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, **kwargs: Any) -> None:
 
 
134
  del kwargs
135
  init_fn_ = _normal_init_(std=std)
136
- generic_param_init_fn_(module=module, init_fn_=init_fn_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim)
 
 
137
 
138
- def baseline_param_init_fn_(module: nn.Module, init_std: Optional[float], n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, **kwargs: Any) -> None:
 
 
139
  del kwargs
140
  if init_std is None:
141
- raise ValueError("You must set model.init_config['init_std'] to a float value to use the default initialization scheme.")
142
- _normal_param_init_fn_(module=module, std=init_std, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim)
 
 
143
 
144
- def small_param_init_fn_(module: nn.Module, n_layers: int, d_model: int, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, **kwargs: Any) -> None:
 
 
145
  del kwargs
146
  std = math.sqrt(2 / (5 * d_model))
147
- _normal_param_init_fn_(module=module, std=std, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim)
 
 
148
 
149
- def neox_param_init_fn_(module: nn.Module, n_layers: int, d_model: int, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, **kwargs: Any) -> None:
 
 
150
  """From section 2.3.1 of GPT-NeoX-20B:
151
 
152
 An Open-Source Autoregressive Language Model — Black et al. (2022)
@@ -155,25 +252,129 @@ def neox_param_init_fn_(module: nn.Module, n_layers: int, d_model: int, emb_init
155
  """
156
  del kwargs
157
  residual_div = n_layers / math.sqrt(10)
158
- small_param_init_fn_(module=module, d_model=d_model, n_layers=n_layers, init_div_is_residual=residual_div, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim)
 
 
159
 
160
- def kaiming_uniform_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, fan_mode: str='fan_in', init_nonlinearity: str='leaky_relu', **kwargs: Any) -> None:
 
 
161
  del kwargs
162
- kaiming_uniform_ = partial(nn.init.kaiming_uniform_, a=init_gain, mode=fan_mode, nonlinearity=init_nonlinearity)
163
- generic_param_init_fn_(module=module, init_fn_=kaiming_uniform_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim)
 
 
164
 
165
- def kaiming_normal_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, fan_mode: str='fan_in', init_nonlinearity: str='leaky_relu', **kwargs: Any) -> None:
 
 
166
  del kwargs
167
- kaiming_normal_ = partial(torch.nn.init.kaiming_normal_, a=init_gain, mode=fan_mode, nonlinearity=init_nonlinearity)
168
- generic_param_init_fn_(module=module, init_fn_=kaiming_normal_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim)
 
 
169
 
170
- def xavier_uniform_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, **kwargs: Any) -> None:
 
 
171
  del kwargs
172
  xavier_uniform_ = partial(torch.nn.init.xavier_uniform_, gain=init_gain)
173
- generic_param_init_fn_(module=module, init_fn_=xavier_uniform_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim)
 
 
174
 
175
- def xavier_normal_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, **kwargs: Any) -> None:
 
 
176
  del kwargs
177
  xavier_normal_ = partial(torch.nn.init.xavier_normal_, gain=init_gain)
178
- generic_param_init_fn_(module=module, init_fn_=xavier_normal_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim)
179
- MODEL_INIT_REGISTRY = {'default_': torch_default_param_init_fn_, 'baseline_': baseline_param_init_fn_, 'kaiming_uniform_': kaiming_uniform_param_init_fn_, 'kaiming_normal_': kaiming_normal_param_init_fn_, 'neox_init_': neox_param_init_fn_, 'small_init_': small_param_init_fn_, 'xavier_uniform_': xavier_uniform_param_init_fn_, 'xavier_normal_': xavier_normal_param_init_fn_}
 
 
7
  from torch import nn
8
  from .fc import FC_CLASS_REGISTRY
9
  from .norm import NORM_CLASS_REGISTRY
10
+
11
  try:
12
  import transformer_engine.pytorch as te
13
  except:
14
  te = None
15
 
16
+
17
  def torch_default_param_init_fn_(module: nn.Module, **kwargs: Any) -> None:
18
  del kwargs
19
+ if hasattr(module, "reset_parameters") and isinstance(
20
+ module.reset_parameters, Callable
21
+ ):
22
  module.reset_parameters()
23
 
24
+
25
  def fused_init_helper_(module: nn.Module, init_fn_: Callable) -> None:
26
+ _fused = getattr(module, "_fused", None)
27
  if _fused is None:
28
+ raise RuntimeError(f"Internal logic error")
29
  assert isinstance(module.weight, torch.Tensor)
30
  (dim, splits) = _fused
31
  splits = (0, *splits, module.weight.size(dim))
32
+ for s, e in zip(splits[:-1], splits[1:]):
33
  slice_indices = [slice(None)] * module.weight.ndim
34
  slice_indices[dim] = slice(s, e)
35
  init_fn_(module.weight[slice_indices])
36
 
37
+
38
+ def generic_param_init_fn_(
39
+ module: nn.Module,
40
+ init_fn_: Callable,
41
+ n_layers: int,
42
+ d_model: Optional[int] = None,
43
+ init_div_is_residual: Union[int, float, str, bool] = True,
44
+ emb_init_std: Optional[float] = None,
45
+ emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
46
+ **kwargs: Any,
47
+ ) -> None:
48
  del kwargs
49
  init_div_is_residual = init_div_is_residual
50
  if init_div_is_residual is False:
51
  div_is_residual = 1.0
52
  elif init_div_is_residual is True:
53
  div_is_residual = math.sqrt(2 * n_layers)
54
+ elif isinstance(init_div_is_residual, float) or isinstance(
55
+ init_div_is_residual, int
56
+ ):
57
  div_is_residual = init_div_is_residual
58
  elif init_div_is_residual.isnumeric():
59
  div_is_residual = float(init_div_is_residual)
60
  else:
61
  div_is_residual = 1.0
62
+ raise ValueError(
63
+ f"Expected init_div_is_residual to be boolean or numeric, got {init_div_is_residual}"
64
+ )
65
  if isinstance(module, tuple(set(FC_CLASS_REGISTRY.values()))):
66
+ if hasattr(module, "_fused"):
67
  fused_init_helper_(module, init_fn_)
68
  else:
69
  init_fn_(module.weight)
70
  if module.bias is not None:
71
  assert isinstance(module.bias, torch.Tensor)
72
  torch.nn.init.zeros_(module.bias)
73
+ if init_div_is_residual is not False and getattr(module, "_is_residual", False):
74
  with torch.no_grad():
75
  module.weight.div_(div_is_residual)
76
  elif isinstance(module, nn.Embedding):
77
  if emb_init_std is not None:
78
  std = emb_init_std
79
  if std == 0:
80
+ warnings.warn(f"Embedding layer initialized to 0.")
81
  emb_init_fn_ = partial(torch.nn.init.normal_, mean=0.0, std=std)
82
  elif emb_init_uniform_lim is not None:
83
  lim = emb_init_uniform_lim
84
  if isinstance(lim, Sequence):
85
  if len(lim) > 2:
86
+ raise ValueError(
87
+ f"Uniform init requires a min and a max limit. User input: {lim}."
88
+ )
89
  if lim[0] == lim[1]:
90
+ warnings.warn(f"Embedding layer initialized to {lim[0]}.")
91
  else:
92
  if lim == 0:
93
+ warnings.warn(f"Embedding layer initialized to 0.")
94
  lim = [-lim, lim]
95
  (a, b) = lim
96
  emb_init_fn_ = partial(torch.nn.init.uniform_, a=a, b=b)
 
98
  emb_init_fn_ = init_fn_
99
  emb_init_fn_(module.weight)
100
  elif isinstance(module, tuple(set(NORM_CLASS_REGISTRY.values()))):
101
+ if hasattr(module, "weight") and isinstance(module.weight, torch.Tensor):
102
  torch.nn.init.ones_(module.weight)
103
+ if hasattr(module, "bias") and isinstance(module.bias, torch.Tensor):
104
  torch.nn.init.zeros_(module.bias)
105
  elif isinstance(module, nn.MultiheadAttention):
106
  if module._qkv_same_embed_dim:
107
  assert module.in_proj_weight is not None
108
+ assert (
109
+ module.q_proj_weight is None
110
+ and module.k_proj_weight is None
111
+ and (module.v_proj_weight is None)
112
+ )
113
  assert d_model is not None
114
  _d = d_model
115
  splits = (0, _d, 2 * _d, 3 * _d)
116
+ for s, e in zip(splits[:-1], splits[1:]):
117
  init_fn_(module.in_proj_weight[s:e])
118
  else:
119
+ assert (
120
+ module.q_proj_weight is not None
121
+ and module.k_proj_weight is not None
122
+ and (module.v_proj_weight is not None)
123
+ )
124
  assert module.in_proj_weight is None
125
  init_fn_(module.q_proj_weight)
126
  init_fn_(module.k_proj_weight)
 
132
  if module.bias_v is not None:
133
  torch.nn.init.zeros_(module.bias_v)
134
  init_fn_(module.out_proj.weight)
135
+ if init_div_is_residual is not False and getattr(
136
+ module.out_proj, "_is_residual", False
137
+ ):
138
  with torch.no_grad():
139
  module.out_proj.weight.div_(div_is_residual)
140
  if module.out_proj.bias is not None:
 
156
  module.fc2_weight.div_(div_is_residual)
157
  else:
158
  for _ in module.parameters(recurse=False):
159
+ raise NotImplementedError(
160
+ f"{module.__class__.__name__} parameters are not initialized by param_init_fn."
161
+ )
162
 
163
+
164
+ def _normal_init_(std: float, mean: float = 0.0) -> Callable:
165
  return partial(torch.nn.init.normal_, mean=mean, std=std)
166
 
167
+
168
+ def _normal_param_init_fn_(
169
+ module: nn.Module,
170
+ std: float,
171
+ n_layers: int,
172
+ d_model: Optional[int] = None,
173
+ init_div_is_residual: Union[int, float, str, bool] = True,
174
+ emb_init_std: Optional[float] = None,
175
+ emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
176
+ **kwargs: Any,
177
+ ) -> None:
178
  del kwargs
179
  init_fn_ = _normal_init_(std=std)
180
+ generic_param_init_fn_(
181
+ module=module,
182
+ init_fn_=init_fn_,
183
+ d_model=d_model,
184
+ n_layers=n_layers,
185
+ init_div_is_residual=init_div_is_residual,
186
+ emb_init_std=emb_init_std,
187
+ emb_init_uniform_lim=emb_init_uniform_lim,
188
+ )
189
 
190
+
191
+ def baseline_param_init_fn_(
192
+ module: nn.Module,
193
+ init_std: Optional[float],
194
+ n_layers: int,
195
+ d_model: Optional[int] = None,
196
+ init_div_is_residual: Union[int, float, str, bool] = True,
197
+ emb_init_std: Optional[float] = None,
198
+ emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
199
+ **kwargs: Any,
200
+ ) -> None:
201
  del kwargs
202
  if init_std is None:
203
+ raise ValueError(
204
+ "You must set model.init_config['init_std'] to a float value to use the default initialization scheme."
205
+ )
206
+ _normal_param_init_fn_(
207
+ module=module,
208
+ std=init_std,
209
+ d_model=d_model,
210
+ n_layers=n_layers,
211
+ init_div_is_residual=init_div_is_residual,
212
+ emb_init_std=emb_init_std,
213
+ emb_init_uniform_lim=emb_init_uniform_lim,
214
+ )
215
+
216
 
217
+ def small_param_init_fn_(
218
+ module: nn.Module,
219
+ n_layers: int,
220
+ d_model: int,
221
+ init_div_is_residual: Union[int, float, str, bool] = True,
222
+ emb_init_std: Optional[float] = None,
223
+ emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
224
+ **kwargs: Any,
225
+ ) -> None:
226
  del kwargs
227
  std = math.sqrt(2 / (5 * d_model))
228
+ _normal_param_init_fn_(
229
+ module=module,
230
+ std=std,
231
+ d_model=d_model,
232
+ n_layers=n_layers,
233
+ init_div_is_residual=init_div_is_residual,
234
+ emb_init_std=emb_init_std,
235
+ emb_init_uniform_lim=emb_init_uniform_lim,
236
+ )
237
 
238
+
239
+ def neox_param_init_fn_(
240
+ module: nn.Module,
241
+ n_layers: int,
242
+ d_model: int,
243
+ emb_init_std: Optional[float] = None,
244
+ emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
245
+ **kwargs: Any,
246
+ ) -> None:
247
  """From section 2.3.1 of GPT-NeoX-20B:
248
 
249
 An Open-Source Autoregressive Language Model — Black et al. (2022)
 
252
  """
253
  del kwargs
254
  residual_div = n_layers / math.sqrt(10)
255
+ small_param_init_fn_(
256
+ module=module,
257
+ d_model=d_model,
258
+ n_layers=n_layers,
259
+ init_div_is_residual=residual_div,
260
+ emb_init_std=emb_init_std,
261
+ emb_init_uniform_lim=emb_init_uniform_lim,
262
+ )
263
+
264
 
265
+ def kaiming_uniform_param_init_fn_(
266
+ module: nn.Module,
267
+ n_layers: int,
268
+ d_model: Optional[int] = None,
269
+ init_div_is_residual: Union[int, float, str, bool] = True,
270
+ emb_init_std: Optional[float] = None,
271
+ emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
272
+ init_gain: float = 0,
273
+ fan_mode: str = "fan_in",
274
+ init_nonlinearity: str = "leaky_relu",
275
+ **kwargs: Any,
276
+ ) -> None:
277
  del kwargs
278
+ kaiming_uniform_ = partial(
279
+ nn.init.kaiming_uniform_,
280
+ a=init_gain,
281
+ mode=fan_mode,
282
+ nonlinearity=init_nonlinearity,
283
+ )
284
+ generic_param_init_fn_(
285
+ module=module,
286
+ init_fn_=kaiming_uniform_,
287
+ d_model=d_model,
288
+ n_layers=n_layers,
289
+ init_div_is_residual=init_div_is_residual,
290
+ emb_init_std=emb_init_std,
291
+ emb_init_uniform_lim=emb_init_uniform_lim,
292
+ )
293
 
294
+
295
+ def kaiming_normal_param_init_fn_(
296
+ module: nn.Module,
297
+ n_layers: int,
298
+ d_model: Optional[int] = None,
299
+ init_div_is_residual: Union[int, float, str, bool] = True,
300
+ emb_init_std: Optional[float] = None,
301
+ emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
302
+ init_gain: float = 0,
303
+ fan_mode: str = "fan_in",
304
+ init_nonlinearity: str = "leaky_relu",
305
+ **kwargs: Any,
306
+ ) -> None:
307
  del kwargs
308
+ kaiming_normal_ = partial(
309
+ torch.nn.init.kaiming_normal_,
310
+ a=init_gain,
311
+ mode=fan_mode,
312
+ nonlinearity=init_nonlinearity,
313
+ )
314
+ generic_param_init_fn_(
315
+ module=module,
316
+ init_fn_=kaiming_normal_,
317
+ d_model=d_model,
318
+ n_layers=n_layers,
319
+ init_div_is_residual=init_div_is_residual,
320
+ emb_init_std=emb_init_std,
321
+ emb_init_uniform_lim=emb_init_uniform_lim,
322
+ )
323
+
324
 
325
+ def xavier_uniform_param_init_fn_(
326
+ module: nn.Module,
327
+ n_layers: int,
328
+ d_model: Optional[int] = None,
329
+ init_div_is_residual: Union[int, float, str, bool] = True,
330
+ emb_init_std: Optional[float] = None,
331
+ emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
332
+ init_gain: float = 0,
333
+ **kwargs: Any,
334
+ ) -> None:
335
  del kwargs
336
  xavier_uniform_ = partial(torch.nn.init.xavier_uniform_, gain=init_gain)
337
+ generic_param_init_fn_(
338
+ module=module,
339
+ init_fn_=xavier_uniform_,
340
+ d_model=d_model,
341
+ n_layers=n_layers,
342
+ init_div_is_residual=init_div_is_residual,
343
+ emb_init_std=emb_init_std,
344
+ emb_init_uniform_lim=emb_init_uniform_lim,
345
+ )
346
 
347
+
348
+ def xavier_normal_param_init_fn_(
349
+ module: nn.Module,
350
+ n_layers: int,
351
+ d_model: Optional[int] = None,
352
+ init_div_is_residual: Union[int, float, str, bool] = True,
353
+ emb_init_std: Optional[float] = None,
354
+ emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
355
+ init_gain: float = 0,
356
+ **kwargs: Any,
357
+ ) -> None:
358
  del kwargs
359
  xavier_normal_ = partial(torch.nn.init.xavier_normal_, gain=init_gain)
360
+ generic_param_init_fn_(
361
+ module=module,
362
+ init_fn_=xavier_normal_,
363
+ d_model=d_model,
364
+ n_layers=n_layers,
365
+ init_div_is_residual=init_div_is_residual,
366
+ emb_init_std=emb_init_std,
367
+ emb_init_uniform_lim=emb_init_uniform_lim,
368
+ )
369
+
370
+
371
+ MODEL_INIT_REGISTRY = {
372
+ "default_": torch_default_param_init_fn_,
373
+ "baseline_": baseline_param_init_fn_,
374
+ "kaiming_uniform_": kaiming_uniform_param_init_fn_,
375
+ "kaiming_normal_": kaiming_normal_param_init_fn_,
376
+ "neox_init_": neox_param_init_fn_,
377
+ "small_init_": small_param_init_fn_,
378
+ "xavier_uniform_": xavier_uniform_param_init_fn_,
379
+ "xavier_normal_": xavier_normal_param_init_fn_,
380
+ }
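Editor's note, illustration only (not part of the commit): a standalone restatement of the `small_init_` entry that `init_config['name']` can select — normal weights with std sqrt(2 / (5 * d_model)), zeroed biases, and an extra 1 / sqrt(2 * n_layers) scaling that `generic_param_init_fn_` applies only to modules flagged `_is_residual`:

```python
import math
import torch
from torch import nn

d_model, n_layers = 4096, 32   # illustrative sizes, not read from any config
std = math.sqrt(2 / (5 * d_model))

layer = nn.Linear(d_model, d_model)
torch.nn.init.normal_(layer.weight, mean=0.0, std=std)
torch.nn.init.zeros_(layer.bias)
with torch.no_grad():          # applied only when the layer is a residual projection
    layer.weight.div_(math.sqrt(2 * n_layers))

print(f"target std before residual scaling: {std:.4f}")
```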
warnings.py ADDED
@@ -0,0 +1,20 @@
 
 
1
+ class VersionedDeprecationWarning(DeprecationWarning):
2
+ """A custom deprecation warning class that includes version information.
3
+ Attributes:
4
+ message (str): The deprecation message describing why the feature is deprecated.
5
+ remove_version (str): The version in which the feature will be removed.
6
+ Example:
7
+ >>> def deprecated_function():
8
+ ... warnings.warn(
9
+ ... VersionedDeprecationWarning(
10
+ ... "Function XYZ is deprecated.",
11
+ ... remove_version="2.0.0"
12
+ ... )
13
+ ... )
14
+ ...
15
+ >>> deprecated_function()
16
+ DeprecationWarning: Function XYZ is deprecated. It will be removed in version 2.0.0.
17
+ """
18
+
19
+ def __init__(self, message: str, remove_version: str) -> None:
20
+ super().__init__(message + f" It will be removed in version {remove_version}.")