Why does the KV Cache implementation concatenate a zeros matrix to xq? #18
Comments
During inference, llama3's seqlen is 1: its generate function feeds only the current_token into the attention layer at each step.

```python
# llama3 attention
def forward(
    self,
    x: torch.Tensor,
    start_pos: int,
    freqs_cis: torch.Tensor,
    mask: Optional[torch.Tensor],
):
    # during inference, seqlen == 1
    bsz, seqlen, _ = x.shape
    xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

    xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
    xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
    xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
    # up to this point, xq, xk, xv all have shape [bsz, 1, *, self.head_dim]
    # the freqs_cis passed in is likewise cis[-1:, :], i.e. [1, head_dim]

    xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
    # before the RoPE embedding is computed, q, k, v need consistent dimensions

    self.cache_k = self.cache_k.to(xq)
    self.cache_v = self.cache_v.to(xq)

    self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
    self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv

    keys = self.cache_k[:bsz, : start_pos + seqlen]
    values = self.cache_v[:bsz, : start_pos + seqlen]

    # repeat k/v heads if n_kv_heads < n_heads
    keys = repeat_kv(keys, self.n_rep)      # (bs, cache_len + seqlen, n_local_heads, head_dim)
    values = repeat_kv(values, self.n_rep)  # (bs, cache_len + seqlen, n_local_heads, head_dim)

    xq = xq.transpose(1, 2)          # (bs, n_local_heads, seqlen, head_dim)
    keys = keys.transpose(1, 2)      # (bs, n_local_heads, cache_len + seqlen, head_dim)
    values = values.transpose(1, 2)  # (bs, n_local_heads, cache_len + seqlen, head_dim)

    scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
```
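As a side note on the mechanism above: llama3 does not grow its cache by concatenation. The cache_k / cache_v buffers are preallocated and filled in place by slice assignment, with start_pos keeping positions aligned. A minimal standalone sketch of just that indexing (the buffer sizes here are made up for illustration):

```python
import torch

# Illustrative sizes only; llama3 preallocates its caches up front and fills
# them in place with slice assignment, so no torch.cat is needed per step.
max_bsz, max_seq_len, n_kv_heads, head_dim = 2, 16, 4, 8
cache_k = torch.zeros(max_bsz, max_seq_len, n_kv_heads, head_dim)

bsz, seqlen, start_pos = 1, 1, 7                        # decoding one token at position 7
xk = torch.randn(bsz, seqlen, n_kv_heads, head_dim)
cache_k[:bsz, start_pos : start_pos + seqlen] = xk      # write the new key in place
keys = cache_k[:bsz, : start_pos + seqlen]              # every key seen so far
print(keys.shape)  # torch.Size([1, 8, 4, 8])
```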
During computation, only the current token is used as …
Unlike LLaMA3, … when only the current token's … is computed, …
Indeed, the former's attention computational complexity is …
I just took a quick look and modified it; this is a KV cache forcibly bolted onto the existing scheme:

```python
def forward(
    self,
    x: torch.Tensor,
    pos_cis: torch.Tensor,
    use_kv_cache: bool = False,
    past_kv: Tuple[torch.Tensor] = None
):
    bsz, seqlen, _ = x.shape
    keys, values = None, None
    flag = 1

    # QKV
    # inference
    if use_kv_cache:
        current_token = x[:, -1:, :]
        if not past_kv:
            xq = self.wq(x)
            xk, xv = self.wk(x), self.wv(x)
            flag = 1
            past_kv = (xk, xv)
        else:
            past_key, past_value = past_kv
            xq = self.wq(current_token)
            xk = self.wk(current_token)
            xv = self.wv(current_token)
            keys = torch.cat((past_key, xk), dim=1)
            values = torch.cat((past_value, xv), dim=1)
            past_kv = (keys, values)
            flag = 2
    else:
        xq = self.wq(x)
        xk, xv = self.wk(x), self.wv(x)

    if flag == 2:
        xq = xq.view(bsz, 1, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, 1, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, 1, self.n_local_kv_heads, self.head_dim)
    else:
        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

    if flag == 1:
        # RoPE relative positional embeddings
        xq, xk = apply_rotary_emb(xq, xk, pos_cis)
    else:
        xq, xk = apply_rotary_emb(xq, xk, pos_cis[-1:, :])

    if flag == 2:
        past_key, past_value = past_kv
        keys = torch.cat((past_key[:, :-1, :], xk.view(bsz, 1, self.n_local_kv_heads * self.head_dim)), dim=1)
        values = torch.cat((past_value[:, :-1, :], xv.view(bsz, 1, self.n_local_kv_heads * self.head_dim)), dim=1)
        past_kv = (keys, values)
        keys = keys.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        values = values.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        xk = keys
        xv = values

    # grouped multiquery attention: expand out keys and values
    xk = repeat_kv(xk, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)
    xv = repeat_kv(xv, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)

    # make heads into a batch dimension
    xq = xq.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
    xk = xk.transpose(1, 2)
    xv = xv.transpose(1, 2)

    # manual implementation
    scores = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_dim)
    assert hasattr(self, 'mask')
    scores = scores + self.mask[:, :, :seqlen, :seqlen]  # (bs, n_local_heads, seqlen, cache_len + seqlen)
    scores = F.softmax(scores.float(), dim=-1).type_as(xq)
    scores = self.attn_dropout(scores)
    output = torch.matmul(scores, xv)  # (bs, n_local_heads, seqlen, head_dim)

    if flag == 2:
        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
    else:
        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)

    # final projection into the residual stream
    output = self.wo(output)
    output = self.resid_dropout(output)
    return output, past_kv
```

For now I won't make major changes to the inference function. Further discussion and corrections are welcome.
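As a shape-only illustration of the cached branch above, here is a standalone sketch (the sizes and the stand-in wk/wv linear layers are made up, not the repo's actual modules) of how past_kv is seeded from the full prompt on the first call and then grown by torch.cat using only the newest token's projections afterwards:

```python
import torch
import torch.nn as nn

# Made-up sizes and stand-in projections -- illustrates only the cached branch.
bsz, dim, n_kv_heads, head_dim = 1, 32, 4, 8
wk = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
wv = nn.Linear(dim, n_kv_heads * head_dim, bias=False)

# First call: the whole prompt is projected and stored as past_kv.
prompt = torch.randn(bsz, 5, dim)
past_kv = (wk(prompt), wv(prompt))            # each (bsz, 5, n_kv_heads * head_dim)

# Later call: only the newest token is projected, then concatenated onto the cache.
current_token = torch.randn(bsz, 1, dim)
xk, xv = wk(current_token), wv(current_token)
keys = torch.cat((past_kv[0], xk), dim=1)     # (bsz, 6, n_kv_heads * head_dim)
values = torch.cat((past_kv[1], xv), dim=1)
past_kv = (keys, values)
print(keys.shape, values.shape)               # torch.Size([1, 6, 32]) torch.Size([1, 6, 32])
```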
In the part of the Attention method that implements the KV cache:
why does xq need to be concatenated with a zeros matrix?
Is it to keep xq's seqlen dimension the same as xk's and xv's? That adds extra computation, and even if the dimensions differ, the subsequent operations should still run correctly.
For example, the llama3 implementation (quoted above) does not concatenate a zeros matrix.
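To illustrate the point in the question, here is a minimal standalone sketch (all sizes are made up) showing that scaled dot-product attention runs fine when the query length differs from the key/value length; the scores simply come out rectangular, so zero-padding xq is not required for the math to go through:

```python
import math
import torch
import torch.nn.functional as F

# Made-up sizes: one query token attending to a longer key/value history.
bs, n_heads, head_dim, q_len, kv_len = 1, 4, 8, 1, 6
xq = torch.randn(bs, n_heads, q_len, head_dim)
keys = torch.randn(bs, n_heads, kv_len, head_dim)
values = torch.randn(bs, n_heads, kv_len, head_dim)

scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(head_dim)  # (1, 4, 1, 6)
probs = F.softmax(scores, dim=-1)
output = torch.matmul(probs, values)                                   # (1, 4, 1, 8)
print(scores.shape, output.shape)
```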