CLASS torch.nn.Transformer(d_model=512, nhead=8, num_encoder_layers=6,
num_decoder_layers=6, dim_feedforward=2048, dropout=0.1,
activation=<function relu>, custom_encoder=None, custom_decoder=None,
layer_norm_eps=1e-05, batch_first=False, norm_first=False, device=None,
dtype=None)
Parameters
d_model – the number of expected features in the encoder/decoder inputs (default=512).
nhead – the number of heads in the multi-head attention modules (default=8).
num_encoder_layers – the number of sub-encoder-layers in the encoder (default=6).
num_decoder_layers – the number of sub-decoder-layers in the decoder (default=6).
dim_feedforward – the hidden dimension of the feedforward network in each layer (default=2048).
dropout – the dropout value (default=0.1).
activation – the activation function of the encoder/decoder intermediate layer; can be a string ("relu" or "gelu") or a unary callable (default: relu).
custom_encoder – a custom encoder module to use in place of the default encoder (default=None).
custom_decoder – a custom decoder module to use in place of the default decoder (default=None).
layer_norm_eps – the eps value in layer normalization components (default=1e-5).
batch_first – if True, the input and output tensors are provided as (batch, seq, feature); otherwise as (seq, batch, feature) (default=False).
norm_first – if True, encoder and decoder layers perform LayerNorm before the attention and feedforward operations; otherwise after (default=False, i.e. after).
forward(src, tgt, src_mask=None, tgt_mask=None, memory_mask=None, src_key_padding_mask=None, tgt_key_padding_mask=None, memory_key_padding_mask=None)
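As a minimal sketch of how the constructor options above combine, the snippet below builds a small model with batch_first=True, norm_first=True and the "gelu" activation string, then runs one forward pass. The sizes (d_model=32, nhead=4, two layers each, dim_feedforward=64) are arbitrary illustrative choices, not recommended values; note that d_model must be divisible by nhead for the underlying multi-head attention.

```python
import torch
import torch.nn as nn

# Small illustrative model; all sizes are arbitrary choices for this sketch.
model = nn.Transformer(
    d_model=32,
    nhead=4,               # d_model must be divisible by nhead
    num_encoder_layers=2,
    num_decoder_layers=2,
    dim_feedforward=64,
    dropout=0.0,
    activation="gelu",     # a unary callable such as torch.nn.functional.gelu is also accepted
    batch_first=True,      # tensors are laid out as (batch, seq, feature)
    norm_first=True,       # LayerNorm is applied before attention/feedforward
)
model.eval()

src = torch.rand(8, 10, 32)  # (N, S, E): batch of 8 source sequences of length 10
tgt = torch.rand(8, 7, 32)   # (N, T, E): batch of 8 target sequences of length 7

with torch.no_grad():
    out = model(src, tgt)

print(out.shape)  # torch.Size([8, 7, 32]): one output per target position
```

The output follows the target's layout: with batch_first=True it is (N, T, E), so the sequence length comes from tgt rather than src.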
Parameters
src – the sequence to the encoder (required).
tgt – the sequence to the decoder (required).
src_mask – the additive mask for the src sequence (optional).
tgt_mask – the additive mask for the tgt sequence (optional).
memory_mask – the additive mask for the encoder output (optional).
src_key_padding_mask – the ByteTensor mask for src keys per batch (optional).
tgt_key_padding_mask – the ByteTensor mask for tgt keys per batch (optional).
memory_key_padding_mask – the ByteTensor mask for memory keys per batch (optional).
Shape
src: (S, E) for unbatched input, (S, N, E) if batch_first=False or (N, S, E) if batch_first=True.
tgt: (T, E) for unbatched input, (T, N, E) if batch_first=False or (N, T, E) if batch_first=True.
src_mask: (S, S) or (N⋅num_heads, S, S).
tgt_mask: (T, T) or (N⋅num_heads, T, T).
memory_mask: (T, S).
src_key_padding_mask: (S) for unbatched input, otherwise (N, S).
tgt_key_padding_mask: (T) for unbatched input, otherwise (N, T).
memory_key_padding_mask: (S) for unbatched input, otherwise (N, S).
output: (T, E) for unbatched input, (T, N, E) if batch_first=False or (N, T, E) if batch_first=True.
where S is the source sequence length, T is the target sequence length, N is the batch size, and E is the feature number (d_model).
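The shape table above can be checked with a short sketch using the default batch_first=False layout. It assumes a causal target mask built with nn.Transformer's generate_square_subsequent_mask helper (not described in this section; it returns a float (T, T) mask with 0.0 on and below the diagonal and -inf above it) and boolean key padding masks, which recent PyTorch releases accept in place of ByteTensor masks, with True marking a key position to ignore. The sizes S, T, N, E are arbitrary illustrative values. The same source padding mask is reused as memory_key_padding_mask because the memory has the source length S.

```python
import torch
import torch.nn as nn

S, T, N, E = 6, 5, 3, 16  # source length, target length, batch size, features (illustrative)

model = nn.Transformer(d_model=E, nhead=4, num_encoder_layers=1,
                       num_decoder_layers=1, dim_feedforward=32, dropout=0.0)
model.eval()

# Default batch_first=False, so inputs are (S, N, E) and (T, N, E).
src = torch.rand(S, N, E)
tgt = torch.rand(T, N, E)

# Float (T, T) causal mask: target position i cannot attend to positions after i.
tgt_mask = model.generate_square_subsequent_mask(T)

# Boolean (N, S) padding mask: True marks keys to ignore.
# Here the last two source positions of the first sequence are treated as padding.
src_key_padding_mask = torch.zeros(N, S, dtype=torch.bool)
src_key_padding_mask[0, -2:] = True

with torch.no_grad():
    out = model(src, tgt,
                tgt_mask=tgt_mask,
                src_key_padding_mask=src_key_padding_mask,
                memory_key_padding_mask=src_key_padding_mask)
    # Unbatched input: drop the batch dimension, giving (S, E) and (T, E).
    out_unbatched = model(src[:, 0], tgt[:, 0], tgt_mask=tgt_mask)

print(out.shape)            # torch.Size([5, 3, 16]), i.e. (T, N, E)
print(out_unbatched.shape)  # torch.Size([5, 16]), i.e. (T, E)
```

In both cases the output takes its sequence length from tgt, matching the output row of the shape table.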