# Copyright (c) 2024, Tri Dao, Albert Gu.

from torch import nn
from torch.nn import functional as F


class GatedMLP(nn.Module):
    """MLP with a multiplicative gate (SwiGLU-style when activation=F.silu).

    fc1 projects to twice the hidden width; the result is split into a value
    branch and a gate branch, combined as value * activation(gate), then
    projected back down by fc2.
    """

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        activation=F.silu,
        bias=False,
        multiple_of=128,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        out_features = out_features if out_features is not None else in_features
        # Default hidden width uses the 8/3 ratio common for gated MLPs, so the
        # parameter count roughly matches a conventional 4x-expansion MLP.
        hidden_features = (
            hidden_features if hidden_features is not None else int(8 * in_features / 3)
        )
        # Round up to a multiple of `multiple_of` for hardware-friendly shapes.
        hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
        # Single fused projection producing both the value and gate branches.
        self.fc1 = nn.Linear(in_features, 2 * hidden_features, bias=bias, **factory_kwargs)
        self.activation = activation
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias, **factory_kwargs)

    def forward(self, x):
        y = self.fc1(x)
        # Split the fused projection into the value and gate halves.
        y, gate = y.chunk(2, dim=-1)
        y = y * self.activation(gate)
        y = self.fc2(y)
        return y
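

# A minimal usage sketch (shapes are illustrative, not from the original file):
# with in_features=256 the hidden width defaults to int(8 * 256 / 3) = 682,
# rounded up to the next multiple of 128 -> 768.
if __name__ == "__main__":
    import torch

    mlp = GatedMLP(in_features=256)
    x = torch.randn(4, 16, 256)  # (batch, seq_len, d_model)
    out = mlp(x)
    assert out.shape == x.shape  # out_features defaults to in_features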