# Copyright (c) 2024, Tri Dao, Albert Gu.

from typing import Optional

import torch
from torch import nn, Tensor

from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn


class Block(nn.Module):
    def __init__(
        self, dim, mixer_cls, mlp_cls, norm_cls=nn.LayerNorm, fused_add_norm=False, residual_in_fp32=False
    ):
"""
|
|
Simple block wrapping a mixer class with LayerNorm/RMSNorm and residual connection"
|
|
|
|
This Block has a slightly different structure compared to a regular
|
|
prenorm Transformer block.
|
|
The standard block is: LN -> MHA/MLP -> Add.
|
|
[Ref: https://arxiv.org/abs/2002.04745]
|
|
Here we have: Add -> LN -> Mixer, returning both
|
|
the hidden_states (output of the mixer) and the residual.
|
|
This is purely for performance reasons, as we can fuse add and LayerNorm.
|
|
The residual needs to be provided (except for the very first block).
|
|
"""
|
|
        super().__init__()
        self.residual_in_fp32 = residual_in_fp32
        self.fused_add_norm = fused_add_norm
        self.norm = norm_cls(dim)
        self.mixer = mixer_cls(dim)
        if mlp_cls is not nn.Identity:
            self.norm2 = norm_cls(dim)
            self.mlp = mlp_cls(dim)
        else:
            self.mlp = None
        if self.fused_add_norm:
            assert RMSNorm is not None, "RMSNorm import failed"
            assert isinstance(
                self.norm, (nn.LayerNorm, RMSNorm)
            ), "Only LayerNorm and RMSNorm are supported for fused_add_norm"

    def forward(
        self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None, **mixer_kwargs
    ):
        r"""Pass the input through the encoder layer.

        Args:
            hidden_states: the sequence to the encoder layer (required).
            residual: hidden_states = Mixer(LN(residual))
        """
        if not self.fused_add_norm:
            # Residual connection
            residual = (hidden_states + residual) if residual is not None else hidden_states
            # Apply the norm
            hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
            if self.residual_in_fp32:
                residual = residual.to(torch.float32)
        else:
            hidden_states, residual = layer_norm_fn(
                hidden_states,
                self.norm.weight,
                self.norm.bias,
                residual=residual,
                prenorm=True,
                residual_in_fp32=self.residual_in_fp32,
                eps=self.norm.eps,
                is_rms_norm=isinstance(self.norm, RMSNorm)
            )
        hidden_states = self.mixer(hidden_states, inference_params=inference_params, **mixer_kwargs)
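
        # Optional second sub-block: when mlp_cls is not nn.Identity, repeat the
        # add -> norm pattern with norm2 and then apply the MLP.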
        if self.mlp is not None:
            if not self.fused_add_norm:
                residual = hidden_states + residual
                hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype))
                if self.residual_in_fp32:
                    residual = residual.to(torch.float32)
            else:
                hidden_states, residual = layer_norm_fn(
                    hidden_states,
                    self.norm2.weight,
                    self.norm2.bias,
                    residual=residual,
                    prenorm=True,
                    residual_in_fp32=self.residual_in_fp32,
                    eps=self.norm2.eps,
                    is_rms_norm=isinstance(self.norm2, RMSNorm)
                )
            hidden_states = self.mlp(hidden_states)
        return hidden_states, residual

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
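

# Minimal usage sketch (illustrative only, assuming the mamba_ssm/Triton import at the
# top of this file is available). `DummyMixer` below is a hypothetical stand-in for a
# real mixer (e.g. a Mamba layer); callers typically supply mixer_cls / mlp_cls via
# functools.partial when building a model.
if __name__ == "__main__":
    from functools import partial

    class DummyMixer(nn.Module):
        """Placeholder mixer: a single Linear layer that ignores inference_params."""

        def __init__(self, dim):
            super().__init__()
            self.proj = nn.Linear(dim, dim)

        def forward(self, x, inference_params=None, **kwargs):
            return self.proj(x)

    dim = 16
    block = Block(
        dim,
        mixer_cls=DummyMixer,
        mlp_cls=nn.Identity,             # nn.Identity means "no MLP sub-block"
        norm_cls=partial(nn.LayerNorm, eps=1e-5),
        fused_add_norm=False,            # use the eager add + norm path in this sketch
        residual_in_fp32=True,
    )
    x = torch.randn(2, 8, dim)           # (batch, seqlen, dim)
    hidden_states, residual = block(x)   # first block: residual starts as None
    print(hidden_states.shape, residual.shape)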