1. The following model_kwargs are not used by the model: ['encoder_hidden_states', 'encoder_attention_mask'] (note: typos in the generate arguments will also show up in this list)
Calling `generate` on text_decoder with the code below raises the error above. The cause is an incompatible transformers version.
```python
from transformers import AutoConfig, BertGenerationDecoder

decoder_config = AutoConfig.from_pretrained(args['text_checkpoint'])
text_decoder = BertGenerationDecoder(config=decoder_config)

# Excerpt from inside the model class, where self.text_decoder is the decoder built above.
# On incompatible transformers versions, generate() rejects the two encoder_* kwargs
# as "not used by the model".
output = self.text_decoder.generate(
    input_ids=cls_input_ids,
    encoder_hidden_states=encoder_hidden_states,
    encoder_attention_mask=encoder_attention_mask,
    max_length=self.args['max_seq_length'],
    do_sample=True,
    num_beams=self.args['beam_size'],
    length_penalty=1.0,
    use_cache=True,
)
```
Solution: install a transformers version in one of these ranges: 4.15.0 <= transformers < 4.22.0, or transformers >= 4.25.0.
For example: pip install transformers==4.25.1 or pip install transformers==4.20.1
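You can check the installed version against the ranges recommended above before running generation. The following is a minimal sketch that uses only the standard library. The helper names `release_tuple` and `in_recommended_range` are hypothetical, and the ranges are copied from this note rather than taken from the transformers changelog.

```python
from importlib import metadata


def release_tuple(version: str) -> tuple:
    """Keep only the numeric release part: "4.25.1" -> (4, 25, 1), "4.26.0.dev0" -> (4, 26, 0)."""
    parts = []
    for piece in version.split(".")[:3]:
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break  # drop suffixes such as "rc1" or "dev0"
            digits += ch
        parts.append(int(digits) if digits else 0)
    while len(parts) < 3:
        parts.append(0)
    return tuple(parts)


def in_recommended_range(version: str) -> bool:
    # Ranges from the note: 4.15.0 <= v < 4.22.0, or v >= 4.25.0
    v = release_tuple(version)
    return (4, 15, 0) <= v < (4, 22, 0) or v >= (4, 25, 0)


if __name__ == "__main__":
    try:
        installed = metadata.version("transformers")
    except metadata.PackageNotFoundError:
        print("transformers is not installed")
    else:
        status = "OK" if in_recommended_range(installed) else "outside the recommended ranges"
        print(f"transformers {installed}: {status}")
```

Pre-release suffixes are ignored, so "4.25.0rc1" is treated as 4.25.0. That is a simplification, which is acceptable for a quick sanity check.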
2. No module named 'transformers.generation_beam_constraints' (with transformers==4.28.1)
(1) Solution
Change: from transformers import generation_beam_constraints
to: from transformers.generation import beam_constraints
(2) A fuller example
Problematic code:
```python
# Runs on transformers==4.23.1
from transformers.generation_beam_constraints import Constraint
from transformers.generation_beam_search import BeamScorer, BeamSearchScorer
from transformers.generation_logits_process import (
    EncoderNoRepeatNGramLogitsProcessor,
    ForcedBOSTokenLogitsProcessor,
    ForcedEOSTokenLogitsProcessor,
    HammingDiversityLogitsProcessor,
    InfNanRemoveLogitsProcessor,
    LogitsProcessorList,
    MinLengthLogitsProcessor,
    NoBadWordsLogitsProcessor,
    NoRepeatNGramLogitsProcessor,
    PrefixConstrainedLogitsProcessor,
    RepetitionPenaltyLogitsProcessor,
    TemperatureLogitsWarper,
    TopKLogitsWarper,
    TopPLogitsWarper,
)
from transformers.generation_stopping_criteria import (
    MaxLengthCriteria,
    MaxTimeCriteria,
    StoppingCriteria,
    StoppingCriteriaList,
    validate_stopping_criteria,
)
```
Fixed code (each `transformers.generation_xxx` module becomes `transformers.generation.xxx`):
```python
# Runs on transformers==4.28.1
from transformers.generation.beam_constraints import Constraint
from transformers.generation.beam_search import BeamScorer, BeamSearchScorer
from transformers.generation.logits_process import (
    EncoderNoRepeatNGramLogitsProcessor,
    ForcedBOSTokenLogitsProcessor,
    ForcedEOSTokenLogitsProcessor,
    HammingDiversityLogitsProcessor,
    InfNanRemoveLogitsProcessor,
    LogitsProcessorList,
    MinLengthLogitsProcessor,
    NoBadWordsLogitsProcessor,
    NoRepeatNGramLogitsProcessor,
    PrefixConstrainedLogitsProcessor,
    RepetitionPenaltyLogitsProcessor,
    TemperatureLogitsWarper,
    TopKLogitsWarper,
    TopPLogitsWarper,
)
from transformers.generation.stopping_criteria import (
    MaxLengthCriteria,
    MaxTimeCriteria,
    StoppingCriteria,
    StoppingCriteriaList,
    validate_stopping_criteria,
)
```
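Sometimes the same code must import under both layouts: the 4.23.1-style `transformers.generation_xxx` and the 4.28.1-style `transformers.generation.xxx`. One option is to try the new module path first and fall back to the old one. The following is a minimal sketch built on the standard library's `importlib`. `import_first` is a hypothetical helper and is not part of transformers.

```python
import importlib
from types import ModuleType


def import_first(*module_names: str) -> ModuleType:
    """Import and return the first module in module_names that exists."""
    errors = []
    for name in module_names:
        try:
            return importlib.import_module(name)
        except ModuleNotFoundError as exc:
            errors.append(f"{name}: {exc}")  # remember why this candidate failed
    raise ModuleNotFoundError("no candidate module found:\n" + "\n".join(errors))


if __name__ == "__main__":
    try:
        beam_constraints = import_first(
            "transformers.generation.beam_constraints",  # layout used by 4.28.1
            "transformers.generation_beam_constraints",  # layout used by 4.23.1
        )
        Constraint = beam_constraints.Constraint
        print("Constraint loaded from", beam_constraints.__name__)
    except ModuleNotFoundError as exc:
        print(exc)
```

Only `ModuleNotFoundError` triggers the fallback. Other import failures inside an existing module still surface, so real errors are not masked.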