{"id":526,"date":"2025-04-17T12:02:00","date_gmt":"2025-04-17T04:02:00","guid":{"rendered":"https:\/\/www.vcoco.top\/?p=526"},"modified":"2025-08-21T21:51:17","modified_gmt":"2025-08-21T13:51:17","slug":"%e6%b7%b1%e5%ba%a6%e5%ad%a6%e4%b9%a0-%e6%89%8b%e6%92%95transformer%ef%bc%88%e6%89%be%e5%b7%a5%e4%bd%9c%e7%af%87%ef%bc%89","status":"publish","type":"post","link":"https:\/\/www.vcoco.top\/index.php\/2025\/04\/17\/%e6%b7%b1%e5%ba%a6%e5%ad%a6%e4%b9%a0-%e6%89%8b%e6%92%95transformer%ef%bc%88%e6%89%be%e5%b7%a5%e4%bd%9c%e7%af%87%ef%bc%89\/","title":{"rendered":"\u6df1\u5ea6\u5b66\u4e60\u2014\u2014\u624b\u6495Transformer\uff08\u627e\u5de5\u4f5c\u7bc7\uff09"},"content":{"rendered":"\n<ul class=\"wp-block-list\">\n<li>\u53c2\u8003\u94fe\u63a5 <a rel=\"noreferrer noopener\" href=\"https:\/\/hwcoder.top\/Manual-Coding-1\" target=\"_blank\">https:\/\/hwcoder.top\/Manual-Coding-1<\/a><\/li>\n<\/ul>\n\n\n\n<h3 class=\"wp-block-heading\">\u5355\u5934\u6ce8\u610f\u529b\u673a\u5236<\/h3>\n\n\n\n<p>Q, K, V\u7684\u7406\u89e3\u3002<\/p>\n\n\n\n<p>\u5047\u8bbeQ\u7531\u5982\u4e0b\u77e9\u9635\u7ec4\u6210\uff08\u4e0d\u8003\u8651batchsize\uff09\uff1a<br>Q[0]\uff1a\u8bcd1\uff1a[dim1, dim2, dim3&#8230;]<\/p>\n\n\n\n<p>Q[1]\uff1a\u8bcd2\uff1a[dim1, dim2, dim3&#8230;]<\/p>\n\n\n\n<p>\u540c\u7406K\uff1a<\/p>\n\n\n\n<p>K[0]\uff1a\u8bcd1\uff1a[dim1, dim2, dim3&#8230;]<\/p>\n\n\n\n<p>K[1]\uff1a\u8bcd2\uff1a[dim1, dim2, dim3&#8230;]<\/p>\n\n\n\n<p>\u540c\u7406V\uff1a<\/p>\n\n\n\n<p>V[0]\uff1a\u8bcd1\uff1a[dim1, dim2, dim3&#8230;]<\/p>\n\n\n\n<p>V[1]\uff1a\u8bcd2\uff1a[dim1, dim2, dim3&#8230;]<\/p>\n\n\n\n<p>\u90a3\u4e48Q @ V.T 
\uff08\u6b64\u5904\u5e94\u4e3a Q @ K.T\uff09\u5c31\u662f\u5728\u5404\u7ef4\u5ea6\u4e0a\u7684\u70b9\u79ef\u548c\uff0c\u5373\u8ba1\u7b97\u4e24\u4e2a\u8bcd\u4e4b\u95f4\u7684\u76f8\u4f3c\u5ea6\uff08\u6ce8\u610f\u529b\u5206\u6570\uff09\u3002\u5982\u679c\u662fQ\u7684\u7b2c\u4e00\u884c\u548cK.T\u7684\u7b2c\u4e00\u5217\uff0c\u90a3\u4e48\u5c31\u662f\u8ba1\u7b97\u7b2c\u4e00\u4e2a\u8bcd\u4e0e\u81ea\u5df1\u7684\u76f8\u4f3c\u5ea6\uff1b\u5982\u679c\u662fK.T\u7684\u7b2ci\u5217\uff0c\u5219\u662f\u8ba1\u7b97\u7b2c\u4e00\u4e2a\u8bcd\u548c\u7b2ci\u4e2a\u8bcd\u4e4b\u95f4\u7684\u76f8\u4f3c\u5ea6\u3002<\/p>\n\n\n\n<p>Q @ K.T\u7684\u7ed3\u679c\u77e9\u9635\uff1a<\/p>\n\n\n\n<p>QK[0]\uff1a\u8bcd1\uff1a[\u8bcd1, \u8bcd2, \u8bcd3&#8230;]<\/p>\n\n\n\n<p>QK[1]\uff1a\u8bcd2\uff1a[\u8bcd1, \u8bcd2, \u8bcd3&#8230;]<\/p>\n\n\n\n<p>\u6b64\u65f6\u518d @ V\u5c31\u662f\u7528\u6ce8\u610f\u529b\u6743\u91cd\u5bf9\u6240\u6709\u8bcd\u7684V\u52a0\u6743\u6c42\u548c\u3002QK[0,:] @ V[:,0] = QK[0,0]*V[0,0]+QK[0,1]*V[1,0]&#8230;\uff08\u5373\u7b2c1\u4e2a\u8bcd\u8f93\u51fa\u7684\u7b2c1\u7ef4\uff09<\/p>\n\n\n\n<h4 class=\"wp-block-heading\"><strong>Softmax\u8ba1\u7b97<\/strong><\/h4>\n\n\n\n<p><strong>\u76f4\u89c2\u4f8b\u5b50<\/strong><\/p>\n\n\n\n<p>\u5047\u8bbe&nbsp;<code>batch_size=1<\/code>\u3001<code>num_heads=1<\/code>\u3001<code>seq_len=3<\/code>\uff0c<code>qk<\/code>&nbsp;\u77e9\u9635\u5982\u4e0b\uff1a<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code>qk = &#91;\n    &#91;1.0, 0.5, 0.2],  # \u7b2c1\u4e2a\u67e5\u8be2\u5bf93\u4e2a\u952e\u7684\u5206\u6570\n    &#91;0.3, 1.2, 0.8],  # \u7b2c2\u4e2a\u67e5\u8be2\u5bf93\u4e2a\u952e\u7684\u5206\u6570\n    &#91;0.7, 0.1, 1.5]   # \u7b2c3\u4e2a\u67e5\u8be2\u5bf93\u4e2a\u952e\u7684\u5206\u6570\n]<\/code><\/pre>\n\n\n\n<p>\u6267\u884c&nbsp;<code>softmax(dim=-1)<\/code>&nbsp;\u540e\uff08\u4fdd\u7559\u4e09\u4f4d\u5c0f\u6570\uff09\uff1a<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code>softmax_qk = &#91;\n    &#91;0.486, 0.295, 0.219],  # \u7b2c1\u4e2a\u67e5\u8be2\u7684\u6743\u91cd\u5206\u914d\n    &#91;0.196, 0.481, 0.323],  # \u7b2c2\u4e2a\u67e5\u8be2\u7684\u6743\u91cd\u5206\u914d\n    &#91;0.265, 0.145, 0.590]   # \u7b2c3\u4e2a\u67e5\u8be2\u7684\u6743\u91cd\u5206\u914d\n]<\/code><\/pre>\n\n\n\n<p>\u6bcf\u4e00\u884c\u7684\u548c\u4e3a 
1\uff0c\u7b26\u5408\u6ce8\u610f\u529b\u6743\u91cd\u7684\u5b9a\u4e49\u3002<\/p>\n\n\n\n<p><strong>\u7ef4\u5ea6\u89e3\u91ca<\/strong><\/p>\n\n\n\n<p>torch.softmax\u5b98\u65b9\u7684\u89e3\u91ca\uff1a<\/p>\n\n\n\n<div class=\"wp-block-group is-vertical is-layout-flex wp-container-core-group-is-layout-8cf370e7 wp-block-group-is-layout-flex\">\n<p class=\"has-small-font-size\">\u53c2\u6570<\/p>\n\n\n\n<p class=\"has-small-font-size\"><strong>dim<\/strong>\uff08<a href=\"https:\/\/docs.python.org\/3\/library\/functions.html#int\"><em>int<\/em><\/a>\uff09\u2013 \u8ba1\u7b97 Softmax \u7684\u7ef4\u5ea6\uff08\u56e0\u6b64 dim \u4e0a\u6bcf\u4e2a\u5207\u7247\u7684\u603b\u548c\u4e3a 1\uff09\u3002<\/p>\n<\/div>\n\n\n\n<p class=\"has-luminous-vivid-amber-color has-text-color\">Question\uff1a-1\u4e0d\u662f\u6700\u540e\u4e00\u7ef4\u5417\uff1f\u5728\u77e9\u9635\u4e2d\u6700\u540e\u4e00\u7ef4\u4e0d\u5c31\u662f\u5217\u5417\uff1f\u4e0d\u662f\u5bf9\u8fd9\u4e00\u5217\u8ba1\u7b97\u8ba9\u8fd9\u4e00\u5217\u52a0\u548c\u4e3a1\u5417\uff1f<\/p>\n\n\n\n<p>\u5047\u8bbe\u6709\u4e00\u4e2a\u77e9\u9635\uff082D\u5f20\u91cf\uff09\uff1a<\/p>\n\n\n\n<div class=\"wp-block-group is-vertical is-layout-flex wp-container-core-group-is-layout-8cf370e7 wp-block-group-is-layout-flex\">\n<pre class=\"wp-block-code\"><code>A = &#91;\n    &#91;a, b, c],  # \u7b2c1\u884c\n    &#91;d, e, f]   # \u7b2c2\u884c\n]<\/code><\/pre>\n\n\n\n<p><strong>\u6700\u540e\u4e00\u7ef4\uff08<code>dim=-1<\/code>\uff09\u662f\u201c\u5217\u201d\u200b<\/strong>\u200b\uff1a<code>a, b, c<\/code>&nbsp;\u662f\u7b2c\u4e00\u884c\u76843\u4e2a\u5217\u5143\u7d20\u3002\u786e\u5b9e\u662f\u5bf9\u6bcf\u4e00\u884c\u7684\u6240\u6709\u5217\u8ba1\u7b97\u4e86softmax\u3002<\/p>\n\n\n\n<p>A[0][0]=a, A[0][1]=b, A[0][2]=c.\u6211\u4eec\u4e00\u822c\u662f\u60f3\u5bf9\u6700\u540e\u4e00\u7ef4softmax\uff0c\u4e5f\u5c31\u662fA[0][0-3]\u8ba1\u7b97\u3002\u90a3\u5c31\u5176\u5b9e\u662f\u5982\u9898\u610f\u4e00\u6837\u3002torch.softmax(A, 
dim=-1)\u5c31\u884c\u3002<\/p>\n<\/div>\n\n\n\n<hr class=\"wp-block-separator has-alpha-channel-opacity\"\/>\n\n\n\n<h4 class=\"wp-block-heading\">\u4ee3\u7801\u5b9e\u73b0<\/h4>\n\n\n\n<pre class=\"wp-block-code\"><code>class SelfAttention(nn.Module):\n    def __init__(self, hidden_size,):\n        super().__init__()\n        self.Q=nn.Linear(hidden_size, hidden_size)\n        self.K=nn.Linear(hidden_size, hidden_size)\n        self.V=nn.Linear(hidden_size, hidden_size)\n        self.linear=nn.Linear(hidden_size, hidden_size)\n    def forward(self, x, causal_mask=None, pad_mask=None):\n        bs, hd = x.shape&#91;0], x.shape&#91;2]\n        q=self.Q(x)\n        k=self.K(x)\n        v=self.V(x) # (bs, len, dim)\n\n        # (len, dim) @ (dim, len) = (len, len)\n        qk = q @ k.transpose(-1, -2) \/ (hd**0.5)\n\n        if causal_mask:\n            qk = qk * causal_mask\n        if pad_mask:\n            qk = qk * pad_mask\n        \n        qk = torch.softmax(qk, dim=-1)\n\n        res = qk @ v \n\n        res=self.linear(res)\n\n        return res\n\n\nx=torch.rand(2, 192, 64)\n# print(x.transpose(-1, -2).shape)\nmodel=SelfAttention(64 )\nres=model(x)\nres.shape\n<\/code><\/pre>\n\n\n\n<h3 class=\"wp-block-heading\">\u591a\u5934\u6ce8\u610f\u529b<\/h3>\n\n\n\n<ul class=\"wp-block-list\">\n<li>\u8fd9\u91cc\u8981\u6ce8\u610f\u8ba1\u7b97\u7684\u65f6\u5019shape\u662f<code>(batch_size, nums_head, seq_len, dim)<\/code>\u800c\u4e0d\u662f<code>(batch_size, seq_len,  nums_head, dim)<\/code>\u5c31\u884c<\/li>\n<\/ul>\n\n\n\n<h4 class=\"wp-block-heading\">\u4ee3\u7801\u5b9e\u73b0<\/h4>\n\n\n\n<pre class=\"wp-block-code\"><code>class MultiHeadCrossAttention(nn.Module):\r\n    def __init__(self, hidden_size, head_nums):\r\n        super().__init__()\r\n        self.Q=nn.Linear(hidden_size, hidden_size)\r\n        self.K=nn.Linear(hidden_size, hidden_size)\r\n        self.V=nn.Linear(hidden_size, hidden_size)\r\n        self.linear=nn.Linear(hidden_size, hidden_size)\r\n      
  self.head_nums=head_nums\r\n    def forward(self, q, key_value, causal_mask=None, pad_mask=None):\r\n        (bs, N, hd) = q.shape\r\n        hd\/=self.head_nums\r\n        q=self.Q(q)\r\n        k=self.K(key_value)\r\n        v=self.V(key_value) # (bs, len, dim)\r\n \r\n        # (bs, len, head_nums, dim\/head_nums) -> (bs,head_nums,len,dim\/head_nums)\r\n        q=q.reshape(bs, N, self.head_nums, -1).transpose(1,2)\r\n        k=k.reshape(bs, N, self.head_nums, -1).transpose(1,2)\r\n        v=v.reshape(bs, N, self.head_nums, -1).transpose(1,2)\r\n \r\n        # (len, dim) @ (dim, len) = (len, len)\r\n        qk = q @ k.transpose(-1, -2) \/ (hd**0.5)\r\n \r\n        if causal_mask:\r\n            qk = qk * causal_mask\r\n        if pad_mask:\r\n            qk = qk * pad_mask\r\n \r\n        qk = torch.softmax(qk, dim=-1)\r\n \r\n        res = qk @ v  #(bs,head_nums,len,dim\/head_nums)\r\n        res = res.transpose(1,2).reshape(bs, N, -1)\r\n \r\n        res=self.linear(res)\r\n \r\n        return res\n\nclass MultiHeadSelfAttention(nn.Module):\r\n    def __init__(self, hidden_size, head_nums):\r\n        super().__init__()\r\n        self.head_nums=head_nums\r\n        self.hidden_size=hidden_size\r\n        self.Q=nn.Linear(hidden_size, hidden_size)\r\n        self.K=nn.Linear(hidden_size, hidden_size)\r\n        self.V=nn.Linear(hidden_size, hidden_size)\r\n        self.linear=nn.Linear(hidden_size, hidden_size)\r\n    def forward(self, x, causal_mask=None, pad_mask=None):\r\n        (bs, N, hd) = x.shape\r\n        q=self.Q(x)\r\n        k=self.K(x)\r\n        v=self.V(x) # (bs, len, dim)\r\n        hd\/=self.head_nums\r\n        q=q.reshape(bs, N, self.head_nums, -1).transpose(1,2)\r\n        k=k.reshape(bs, N, self.head_nums, -1).transpose(1,2)\r\n        v=v.reshape(bs, N, self.head_nums, -1).transpose(1,2)\r\n \r\n        # (len, dim) @ (dim, len) = (len, len)\r\n        qk = q @ k.transpose(-1, -2) \/ (hd**0.5)\r\n \r\n        if causal_mask:\r\n         
   qk = qk * causal_mask\r\n        if pad_mask:\r\n            qk = qk * pad_mask\r\n        \r\n        qk = torch.softmax(qk, dim=-1)\r\n \r\n        res = qk @ v \r\n        res = res.transpose(1,2).reshape(bs, N, -1)\r\n \r\n        res=self.linear(res)\r\n \r\n        return res\r\n \r\n\r\n \r\nif __name__=='__main__':\r\n    '''\r\n    multi-head attention\r\n    '''\r\n    q=torch.rand(2, 192, 64)\r\n    encoder_output=torch.rand(2, 192, 64)\r\n    # print(x.transpose(-1, -2).shape)\r\n    model=MultiHeadCrossAttention(64, 8)\r\n    res=model(q, encoder_output)\r\n    print(res.shape)\r\n\r\n    '''\r\n    self-attention\r\n    '''\r\n    x=torch.rand(2, 192, 64)\r\n    # print(x.transpose(-1, -2).shape)\r\n    model=MultiHeadSelfAttention(64, 8)\r\n    res=model(x)\r\n    print(res.shape)\n<\/code><\/pre>\n\n\n\n<h3 class=\"wp-block-heading\">Tranformer Encoder<\/h3>\n\n\n\n<figure class=\"wp-block-image size-full is-resized\"><div class='fancybox-wrapper lazyload-container-unload' data-fancybox='post-images' href='https:\/\/www.vcoco.top\/wp-content\/uploads\/2025\/05\/image-8.png'><img class=\"lazyload lazyload-style-1\" src=\"data:image\/svg+xml;base64,PCEtLUFyZ29uTG9hZGluZy0tPgo8c3ZnIHdpZHRoPSIxIiBoZWlnaHQ9IjEiIHhtbG5zPSJodHRwOi8vd3d3LnczLm9yZy8yMDAwL3N2ZyIgc3Ryb2tlPSIjZmZmZmZmMDAiPjxnPjwvZz4KPC9zdmc+\"  loading=\"lazy\" decoding=\"async\" width=\"741\" height=\"832\" data-original=\"https:\/\/www.vcoco.top\/wp-content\/uploads\/2025\/05\/image-8.png\" src=\"data:image\/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAAAXNSR0IArs4c6QAAAARnQU1BAACxjwv8YQUAAAAJcEhZcwAADsQAAA7EAZUrDhsAAAANSURBVBhXYzh8+PB\/AAffA0nNPuCLAAAAAElFTkSuQmCC\" alt=\"\" class=\"wp-image-546\" style=\"aspect-ratio:0.890625;width:398px;height:auto\"  sizes=\"auto, (max-width: 741px) 100vw, 741px\" \/><\/div><\/figure>\n\n\n\n<h4 class=\"wp-block-heading\">Position Embedding &amp; Token Embedding <\/h4>\n\n\n\n<p>\u4f20\u7edf\u4f4d\u7f6e\u7f16\u7801\uff1a<\/p>\n\n\n\n<div 
class=\"wp-block-group is-vertical is-layout-flex wp-container-core-group-is-layout-8cf370e7 wp-block-group-is-layout-flex\">\n<p>$$ \\mathbf{PE}<em>{(pos, 2i)} = \\sin\\left(\\frac{pos}{10000^{\\frac{2i}{d<\/em>{\\text{model}}}}}\\right) $$<\/p>\n\n\n\n<p>\u5176\u4e2d\uff0c$-&nbsp;pos$&nbsp;\u8868\u793a\u4f4d\u7f6e\u7d22\u5f15\u3002 $-&nbsp;i$&nbsp;\u8868\u793a\u7ef4\u5ea6\u7d22\u5f15\u3002 $-&nbsp;d_model$&nbsp;\u8868\u793a\u5d4c\u5165\u7ef4\u5ea6\u7684\u5927\u5c0f\u3002<\/p>\n<\/div>\n\n\n\n<pre class=\"wp-block-code\"><code>from torch import nn\r\nimport torch\r\nimport math\r\nfrom transformers import AutoTokenizer, AutoModelForMaskedLM\r\n\r\n\r\n\r\nclass TokenEmbedding(nn.Module):\r\n    def __init__(self, vocab_size, hidden_size):\r\n        super().__init__()\r\n        self.embedding = nn.Embedding(vocab_size, hidden_size)  # \u5d4c\u5165\u5c42\r\n    \r\n    def forward(self, x):\r\n        # x \u5f62\u72b6: (batch_size, seq_len)\r\n        embedded = self.embedding(x)  # \u5d4c\u5165\u540e\u7684\u5f62\u72b6: (batch_size, seq_len, hidden_size)\r\n        return embedded\r\n \r\nclass PositionalEmbedding(nn.Module):\r\n    def __init__(self, max_len, hidden_size):\r\n        super().__init__()\r\n        self.hidden_size = hidden_size\r\n        \r\n        # \u521b\u5efa\u4f4d\u7f6e\u7f16\u7801\u8868\uff0c\u5927\u5c0f\u4e3a (max_len, hidden_size)\r\n        # position: (max_len, 1)\uff0c\u8868\u793a\u5e8f\u5217\u4e2d\u7684\u4f4d\u7f6e\u7d22\u5f15\uff0c\u4f8b\u5982 &#91;&#91;0.], &#91;1.], &#91;2.], ...]\r\n        position = torch.arange(0, max_len).unsqueeze(1).float()\r\n        \r\n        # div_term: (hidden_size \/ 2)\uff0c\u7528\u4e8e\u8ba1\u7b97\u4f4d\u7f6e\u7f16\u7801\u7684\u5206\u6bcd\r\n        div_term = torch.exp(torch.arange(0, hidden_size, 2).float() * (-math.log(10000.0) \/ hidden_size))\r\n        \r\n        # \u521d\u59cb\u5316\u4f4d\u7f6e\u7f16\u7801\u77e9\u9635 pe \u4e3a\u96f6\u77e9\u9635\uff0c\u5927\u5c0f\u4e3a (max_len, 
hidden_size)\r\n        pe = torch.zeros(max_len, hidden_size)\r\n        \r\n        # \u8ba1\u7b97\u4f4d\u7f6e\u7f16\u7801\u77e9\u9635\uff0c\u5e7f\u64ad\u673a\u5236\u5c06 dive_term \u6269\u5c55\u4e3a (1, hidden_size )\r\n        # \u5076\u6570\u7d22\u5f15\u5217\u4f7f\u7528 sin \u51fd\u6570\r\n        pe&#91;:, 0::2] = torch.sin(position * div_term)\r\n        # \u5947\u6570\u7d22\u5f15\u5217\u4f7f\u7528 cos \u51fd\u6570\r\n        pe&#91;:, 1::2] = torch.cos(position * div_term)\r\n        \r\n        # \u5c06\u4f4d\u7f6e\u7f16\u7801\u77e9\u9635\u6ce8\u518c\u4e3a buffer\uff0c\u6a21\u578b\u8bad\u7ec3\u65f6\u4e0d\u4f1a\u66f4\u65b0\u5b83\r\n        self.register_buffer('pe', pe)\r\n    \r\n    def forward(self, x):\r\n        # x \u7684\u5f62\u72b6: (batch_size, seq_len, hidden_size)\r\n        seq_len = x.size(1)\r\n        \r\n        # \u5c06\u4f4d\u7f6e\u7f16\u7801\u52a0\u5230\u8f93\u5165\u5f20\u91cf\u4e0a\r\n        # self.pe&#91;:seq_len, :] \u7684\u5f62\u72b6\u4e3a (seq_len, hidden_size)\r\n        # unsqueeze(0) \u4f7f\u5176\u5f62\u72b6\u53d8\u4e3a (1, seq_len, hidden_size)\uff0c\u4fbf\u4e8e\u4e0e\u8f93\u5165\u5f20\u91cf\u76f8\u52a0\r\n        x = x + self.pe&#91;:seq_len, :].unsqueeze(0)\r\n        \r\n        # \u8fd4\u56de\u52a0\u4e0a\u4f4d\u7f6e\u7f16\u7801\u540e\u7684\u5f20\u91cf\r\n        return x\r\n\r\n\r\nif __name__ == \"__main__\":\r\n    # 1) \u53d6\u4e00\u4e2a\u771f\u6b63\u5927\u6a21\u578b\u7684\u8bcd\u8868\uff08BERT base\uff09\r\n    # Load model directly\r\n    tokenizer = AutoTokenizer.from_pretrained(\"google-bert\/bert-base-uncased\")\r\n    # model = AutoModelForMaskedLM.from_pretrained(\"google-bert\/bert-base-uncased\")\r\n\r\n    vocab_size = tokenizer.vocab_size          # 30522\r\n    hidden_size = 768                          # \u8ddf bert-base-uncased \u4e00\u81f4\r\n\r\n    # 2) \u5b9e\u4f8b\u5316\u81ea\u5b9a\u4e49\u7684 TokenEmbedding\r\n    model = TokenEmbedding(vocab_size, hidden_size)\r\n\r\n    # 3) 
\u51c6\u5907\u82e5\u5e72\u53e5\u5b50\r\n    sentences = &#91;\r\n        \"Transformers are amazing!\",\r\n        \"Let's test our token embedding layer.\"\r\n    ]\r\n\r\n    # 4) \u4f7f\u7528 tokenizer \u8f6c\u6210 token ids\uff1b\u540c\u65f6\u7edf\u4e00 sequence length\uff08padding\uff09\r\n    encoded = tokenizer(\r\n        sentences,\r\n        padding=True,\r\n        truncation=True,\r\n        return_tensors=\"pt\"\r\n    )\r\n\r\n    input_ids = encoded&#91;\"input_ids\"]           # shape: (batch_size, seq_len)\r\n\r\n    print(\"Token IDs:\\n\", input_ids)\r\n    print(\"Shape:\", input_ids.shape)\r\n\r\n    # 5) \u5582\u8fdb\u5d4c\u5165\u5c42\r\n    with torch.no_grad():\r\n        embedded = model(input_ids)            # (batch_size, seq_len, hidden_size)\r\n\r\n    print(\"Embedded Tensor Shape:\", embedded.shape, '\\n')\r\n    \r\n    embedding_table = model.embedding.weight           # (vocab_size, hidden_size)\r\n\r\n    # -------- \u4ece embedded \u53cd\u63a8\u51fa token id --------------\r\n    # \u628a (batch, seq_len, hidden) \u62c9\u5e73\u6210 (batch*seq_len, hidden)\r\n    flat_emb = embedded.view(-1, hidden_size)          # (B*S, H)\r\n\r\n    # \u70b9\u79ef\u5f97\u5230\u76f8\u4f3c\u5ea6\uff0c\u8d8a\u5927\u8d8a\u76f8\u4f3c\r\n    sim = torch.matmul(flat_emb, embedding_table.t())  # (B*S, V)\r\n    print('embedding_table.shape:', embedding_table.shape)\r\n\r\n    # \u5bf9 vocab \u7ef4\u5ea6\u53d6 argmax \u5f97\u5230\u6700\u76f8\u4f3c\u7684 token id\r\n    pred_ids = torch.argmax(sim, dim=-1)               # (B*S, )\r\n\r\n    # reshape \u56de\u539f\u6765\u7684 (batch, seq_len)\r\n    pred_ids = pred_ids.view(embedded.size(0), embedded.size(1))\r\n\r\n    print(\"\\nRecovered ids:\")\r\n    print(pred_ids)\r\n\r\n    # -------- 2. 
\u89e3\u7801\u6210\u5b57\u7b26\u4e32 -----------------------------\r\n    decoded_text = tokenizer.batch_decode(\r\n        pred_ids,\r\n        skip_special_tokens=False,    # \u4e3a\u4e86\u5bf9\u7167\uff0c\u5148\u4e0d\u53bb\u6389 &#91;PAD]\/&#91;CLS]\/&#91;SEP]\r\n        clean_up_tokenization_spaces=True\r\n    )\r\n\r\n    print('original text')\r\n    print(sentences)\r\n    \r\n    print(\"\\nDecoded text:\")\r\n    for i, sent in enumerate(decoded_text):\r\n        print(f\"{i}: {sent}\")\r\n\r\n    \r\n    pos_emb_model = PositionalEmbedding(vocab_size, hidden_size)\r\n    pos_emb_embedded = pos_emb_model(embedded)\r\n    print('\\npos_emb_embedded.shape:', pos_emb_embedded.shape)\r\n    \r\n\n<\/code><\/pre>\n\n\n\n<h4 class=\"wp-block-heading\">Encoder Layer<\/h4>\n\n\n\n<pre class=\"wp-block-code\"><code>class EncoderLayer(nn.Module):\n    def __init__(self, hidden_size, num_heads, ff_size, dropout_prob=0.1):\n        super().__init__()\n        self.multi_head_attention = MultiHeadSelfAttention(hidden_size, num_heads)  # \u591a\u5934\u6ce8\u610f\u529b\u5c42\n        self.dropout1 = nn.Dropout(dropout_prob)  # Dropout \u5c42\n        self.layer_norm1 = nn.LayerNorm(hidden_size)  # LayerNorm \u5c42\n\n        self.feed_forward = nn.Sequential(\n            nn.Linear(hidden_size, ff_size),  # \u524d\u9988\u5c421\n            nn.ReLU(),  # \u6fc0\u6d3b\u51fd\u6570\n            nn.Linear(ff_size, hidden_size)  # \u524d\u9988\u5c422\n        )\n        self.dropout2 = nn.Dropout(dropout_prob)  # Dropout \u5c42\n        self.layer_norm2 = nn.LayerNorm(hidden_size)  # LayerNorm \u5c42\n    \n    def forward(self, x, attention_mask=None):\n        # \u591a\u5934\u6ce8\u610f\u529b\u5b50\u5c42\n        attn_output = self.multi_head_attention(x, attention_mask)  # (batch_size, seq_len, hidden_size)\n        attn_output = self.dropout1(attn_output)  # Dropout\n        out1 = self.layer_norm1(x + attn_output)  # \u6b8b\u5dee\u8fde\u63a5 + LayerNorm\n        \n        
# \u524d\u9988\u795e\u7ecf\u7f51\u7edc\u5b50\u5c42\n        ff_output = self.feed_forward(out1)  # (batch_size, seq_len, hidden_size)\n        ff_output = self.dropout2(ff_output)  # Dropout\n        out2 = self.layer_norm2(out1 + ff_output)  # \u6b8b\u5dee\u8fde\u63a5 + LayerNorm\n        \n        return out2\n\nx=torch.rand(2, 192, 64)\nmodel=EncoderLayer(64, 8, 32)\nmodel(x).shape<\/code><\/pre>\n\n\n\n<h3 class=\"wp-block-heading\">Transformer Decoder<\/h3>\n\n\n\n<p>Transformer\u7684Decoder\u5c42\u4e0eEncoder\u5927\u4f53\u7c7b\u4f3c\uff0c\u9664\u4e86\u5305\u542b\u591a\u5934\u81ea\u6ce8\u610f\u529b\u673a\u5236\u548c\u524d\u9988\u795e\u7ecf\u7f51\u7edc\uff0c\u8fd8\u589e\u52a0\u4e86\u4e00\u4e2a\u7528\u4e8e\u7f16\u7801\u5668-\u89e3\u7801\u5668\u6ce8\u610f\u529b\u673a\u5236\u7684\u591a\u5934\u6ce8\u610f\u529b\u5b50\u5c42\u3002\u8fd9\u4f7f\u5f97 Decoder \u5c42\u80fd\u591f\u540c\u65f6\u5173\u6ce8\u5f53\u524d\u8f93\u51fa\u5e8f\u5217\u7684\u4e0a\u4e0b\u6587\u4fe1\u606f\u548c\u8f93\u5165\u5e8f\u5217\u7684\u7f16\u7801\u4fe1\u606f\u3002<\/p>\n\n\n\n<p><strong>CrossAttention<\/strong><\/p>\n\n\n\n<pre class=\"wp-block-code\"><code>import torch.nn as nn\nimport torch\nclass MultiHeadCrossAttention(nn.Module):\n    def __init__(self, hidden_size, head_nums):\n        super().__init__()\n        self.Q=nn.Linear(hidden_size, hidden_size)\n        self.K=nn.Linear(hidden_size, hidden_size)\n        self.V=nn.Linear(hidden_size, hidden_size)\n        self.linear=nn.Linear(hidden_size, hidden_size)\n        self.head_nums=head_nums\n    def forward(self, q, key_value, causal_mask=None, pad_mask=None):\n        (bs, N, hd) = q.shape\n        hd\/=self.head_nums\n        q=self.Q(q)\n        k=self.K(key_value)\n        v=self.V(key_value) # (bs, len, dim)\n\n        # (bs, len, head_nums, dim\/head_nums) -&gt; (bs,head_nums,len,dim\/head_nums)\n        q=q.reshape(bs, N, self.head_nums, -1).transpose(1,2)\n        # keys\/values come from key_value, whose sequence length may differ from N\n        k=k.reshape(bs, key_value.shape&#91;1], self.head_nums, -1).transpose(1,2)\n        
v=v.reshape(bs, key_value.shape&#91;1], self.head_nums, -1).transpose(1,2)\n\n        # (len, dim) @ (dim, len) = (len, len)\n        qk = q @ k.transpose(-1, -2) \/ (hd**0.5)\n\n        # positions where the mask is 0 are set to -inf so softmax gives them zero weight\n        if causal_mask is not None:\n            qk = qk.masked_fill(causal_mask == 0, float('-inf'))\n        if pad_mask is not None:\n            qk = qk.masked_fill(pad_mask == 0, float('-inf'))\n\n        qk = torch.softmax(qk, dim=-1)\n\n        res = qk @ v  #(bs,head_nums,len,dim\/head_nums)\n        res = res.transpose(1,2).reshape(bs, N, -1)\n\n        res=self.linear(res)\n\n        return res\n\n\nq=torch.rand(2, 192, 64)\nencoder_output=torch.rand(2, 192, 64)\n# print(x.transpose(-1, -2).shape)\nmodel=MultiHeadCrossAttention(64, 8)\nres=model(q, encoder_output)\nres.shape\n<\/code><\/pre>\n\n\n\n<p><strong>Decoder Layer<\/strong><\/p>\n\n\n\n<pre class=\"wp-block-code\"><code>class DecoderLayer(nn.Module):\n    def __init__(self, hidden_size, num_heads, ff_size, dropout_prob=0.1):\n        super().__init__()\n        self.multi_head_attention = MultiHeadSelfAttention(hidden_size, num_heads)  # \u591a\u5934\u6ce8\u610f\u529b\u5c42\n        self.dropout1 = nn.Dropout(dropout_prob)  # Dropout \u5c42\n        self.layer_norm1 = nn.LayerNorm(hidden_size)  # LayerNorm \u5c42\n\n        self.multi_head_attention2 = MultiHeadCrossAttention(hidden_size, num_heads)  # \u591a\u5934\u6ce8\u610f\u529b\u5c42\n        self.dropout2 = nn.Dropout(dropout_prob)  # Dropout \u5c42\n        self.layer_norm2 = nn.LayerNorm(hidden_size)  # LayerNorm \u5c42\n\n        self.feed_forward = nn.Sequential(\n            nn.Linear(hidden_size, ff_size),  # \u524d\u9988\u5c421\n            nn.ReLU(),  # \u6fc0\u6d3b\u51fd\u6570\n            nn.Linear(ff_size, hidden_size)  # \u524d\u9988\u5c422\n        )\n        self.dropout3 = nn.Dropout(dropout_prob)  # Dropout \u5c42\n        self.layer_norm3 = nn.LayerNorm(hidden_size)  # LayerNorm \u5c42\n    \n    def forward(self, x, encoder_output, attention_mask=None):\n        # \u591a\u5934\u6ce8\u610f\u529b\u5b50\u5c42\n        attn_output = 
self.multi_head_attention(x, attention_mask)  # (batch_size, seq_len, hidden_size)\n        attn_output = self.dropout1(attn_output)  # Dropout\n        out1 = self.layer_norm1(x + attn_output)  # \u6b8b\u5dee\u8fde\u63a5 + LayerNorm\n\n        # cross-attention sublayer: queries from out1, keys\/values from encoder_output\n        cross_out=self.multi_head_attention2(out1, encoder_output)\n        cross_attn_output = self.dropout2(cross_out)  # Dropout\n        out2 = self.layer_norm2(out1 + cross_attn_output)  # \u6b8b\u5dee\u8fde\u63a5 + LayerNorm\n        \n        # \u524d\u9988\u795e\u7ecf\u7f51\u7edc\u5b50\u5c42\n        ff_output = self.feed_forward(out2)  # (batch_size, seq_len, hidden_size)\n        ff_output = self.dropout3(ff_output)  # Dropout\n        out3 = self.layer_norm3(out2 + ff_output)  # \u6b8b\u5dee\u8fde\u63a5 + LayerNorm\n        \n        return out3\n\nx=torch.rand(2, 192, 64)\nencoder_output=torch.rand(2, 192, 64)\nmodel=DecoderLayer(64, 8, 32)\nmodel(x, encoder_output).shape<\/code><\/pre>\n","protected":false},"excerpt":{"rendered":"<p>\u5355\u5934\u6ce8\u610f\u529b\u673a\u5236 Q, K, V\u7684\u7406\u89e3\u3002 \u5047\u8bbeQ\u7531\u5982\u4e0b\u77e9\u9635\u7ec4\u6210\uff08\u4e0d\u8003\u8651batchsize\uff09\uff1aQ[0]\uff1a\u8bcd1\uff1a[d 
[&hellip;]<\/p>\n","protected":false},"author":1,"featured_media":0,"comment_status":"open","ping_status":"open","sticky":false,"template":"","format":"standard","meta":{"footnotes":""},"categories":[83],"tags":[103,104],"class_list":["post-526","post","type-post","status-publish","format-standard","hentry","category-deep-learning","tag-attention","tag-transformer"],"_links":{"self":[{"href":"https:\/\/www.vcoco.top\/index.php\/wp-json\/wp\/v2\/posts\/526","targetHints":{"allow":["GET"]}}],"collection":[{"href":"https:\/\/www.vcoco.top\/index.php\/wp-json\/wp\/v2\/posts"}],"about":[{"href":"https:\/\/www.vcoco.top\/index.php\/wp-json\/wp\/v2\/types\/post"}],"author":[{"embeddable":true,"href":"https:\/\/www.vcoco.top\/index.php\/wp-json\/wp\/v2\/users\/1"}],"replies":[{"embeddable":true,"href":"https:\/\/www.vcoco.top\/index.php\/wp-json\/wp\/v2\/comments?post=526"}],"version-history":[{"count":14,"href":"https:\/\/www.vcoco.top\/index.php\/wp-json\/wp\/v2\/posts\/526\/revisions"}],"predecessor-version":[{"id":662,"href":"https:\/\/www.vcoco.top\/index.php\/wp-json\/wp\/v2\/posts\/526\/revisions\/662"}],"wp:attachment":[{"href":"https:\/\/www.vcoco.top\/index.php\/wp-json\/wp\/v2\/media?parent=526"}],"wp:term":[{"taxonomy":"category","embeddable":true,"href":"https:\/\/www.vcoco.top\/index.php\/wp-json\/wp\/v2\/categories?post=526"},{"taxonomy":"post_tag","embeddable":true,"href":"https:\/\/www.vcoco.top\/index.php\/wp-json\/wp\/v2\/tags?post=526"}],"curies":[{"name":"wp","href":"https:\/\/api.w.org\/{rel}","templated":true}]}}