torch.nn.Module
该模块是所有神经网络模块的基类。
使用方法:在构造函数中声明各个层,在 forward 方法中将这些层连接起来(前向传播)。
forward 方法
举例:
class Head(nn.Module):
""" 单头自注意力机制 """
def __init__(self, head_size):
super().__init__()
self.key = nn.Linear(n_embd, head_size, bias=False)
self.query = nn.Linear(n_embd, head_size, bias=False)
self.value = nn.Linear(n_embd, head_size, bias=False)
self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
self.dropout = nn.Dropout(dropout)
def forward(self, x):
B,T,C = x.shape
k = self.key(x) # (B,T,C)
q = self.query(x) # (B,T,C)
# compute attention scores ("affinities")
wei = q @ k.transpose(-2,-1) * C**-0.5 # (B, T, C) @ (B, C, T) -> (B, T, T)
wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # (B, T, T)
wei = F.softmax(wei, dim=-1) # (B, T, T)
wei = self.dropout(wei)
# perform the weighted aggregation of the values
v = self.value(x) # (B,T,C)
out = wei @ v # (B, T, T) @ (B, T, C) -> (B, T, C)
return out
怎么看:
在 PyTorch 中,forward
方法是定义一个模型如何处理输入数据的地方。forward
方法内部的代码流程就是数据流向的具体描述。我们可以通过逐步阅读和理解每一行代码来看出数据是如何被处理的,以及各个层和操作的执行顺序。现在,我将详细解释你提供的 forward
方法中的操作流程:
-
获取输入维度:
B, T, C = x.shape
这行代码是从输入数据
x
中提取维度。B
代表批次大小(batch size),T
代表序列长度(time steps 或 tokens),C
代表每个时间步的特征或嵌入维度。 -
生成 Key、Query 和 Value:
k = self.key(x) # (B,T,C) q = self.query(x) # (B,T,C) v = self.value(x) # (B,T,C)
这里,输入数据
x
被分别传递到三个不同的线性层(nn.Linear
),生成自注意力机制中的 key(k
)、query(q
)和 value(v
)。这三个变量都保持和输入相同的维度。 -
计算注意力得分(Affinity scores):
wei = q @ k.transpose(-2, -1) * C**-0.5
在这一步,
query
矩阵和key
矩阵的转置进行点积运算,得到每个 query 对所有 keys 的相似度得分,再通过除以 (C^{-0.5}) 进行缩放(这通常是为了防止梯度过小)。 -
应用 Mask:
wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
这里使用一个下三角矩阵(
tril
)来将不应关注的位置(通常是未来的时间步)置为负无穷(-inf
)。这样在下一步的 softmax 计算中,这些位置的权重会接近于零。 -
Softmax 归一化:
wei = F.softmax(wei, dim=-1)
使用 softmax 函数对
wei
每行进行归一化,这样每行的所有元素加起来为 1,表示概率分布。 -
应用 Dropout:
wei = self.dropout(wei)
通过 dropout 随机地将一些元素置零,这是一种常见的正则化手段,可以帮助防止模型过拟合。
-
加权聚合 Value:
out = wei @ v
最后,使用上述得到的权重矩阵
wei
和 value 矩阵v
进行点积运算,实现了加权聚合。输出的out
仍保持原有的维度(B, T, C)
,每个元素是 value 的加权和,权重由相应的注意力得分确定。
在理解代码的过程中,关键在于跟踪每个变量的维度和它们之间的运算关系,从而把握数据是如何流动和转换的。希望这个解释能帮助你更好地理解 PyTorch 中的 forward
方法。如果还有其他问题,随时告诉我!
网络资源
本文作者:Maeiee
本文链接:torch.nn.Module
版权声明:如无特别声明,本文即为原创文章,版权归 Maeiee 所有,未经允许不得转载!
喜欢我文章的朋友请随缘打赏,鼓励我创作更多更好的作品!