leoxiaobin YenChunChen commited on
Commit
6065b7a
1 Parent(s): b161793

Fix generation when `repetition_penalty` is activated (#57)

Browse files

- make sure input_ids do not contain negative numbers (indicating images) after they are no longer needed (5905c926df4db18660da263a9777998ca66a14fe)


Co-authored-by: Yen-Chun Chen <[email protected]>

Files changed (1) hide show
  1. image_embedding_phi3_v.py +10 -1
image_embedding_phi3_v.py CHANGED
@@ -12,6 +12,7 @@
12
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
  # See the License for the specific language governing permissions and
14
  # limitations under the License.
 
15
 
16
  import torch
17
  from torch import nn
@@ -191,7 +192,15 @@ class Phi3ImageEmbedding(nn.Module):
191
  # positions for image tokens
192
  positions = torch.nonzero((input_ids < 0) & (input_ids > -MAX_INPUT_ID), as_tuple=True)
193
  has_image = len(positions[0].tolist()) > 0
194
- input_ids = input_ids.clamp_min(0).clamp_max(self.vocab_size).detach()
 
 
 
 
 
 
 
 
195
  hidden_states = self.wte(input_ids)
196
 
197
  if has_image:
 
12
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
  # See the License for the specific language governing permissions and
14
  # limitations under the License.
15
+ import warnings
16
 
17
  import torch
18
  from torch import nn
 
192
  # positions for image tokens
193
  positions = torch.nonzero((input_ids < 0) & (input_ids > -MAX_INPUT_ID), as_tuple=True)
194
  has_image = len(positions[0].tolist()) > 0
195
+ # input_ids = input_ids.clamp_min(0).clamp_max(self.vocab_size).detach()
196
+ input_ids.clamp_min_(0).clamp_max_(self.vocab_size)
197
+ warnings.warn(
198
+ "Phi-3-V modifies `input_ids` in-place and the tokens indicating images will be "
199
+ "removed after model forward. If your workflow requires multiple forward passes on "
200
+ "the same `input_ids`, please make a copy of `input_ids` before passing it to the "
201
+ "model."
202
+ )
203
+
204
  hidden_states = self.wte(input_ids)
205
 
206
  if has_image: