Begin simplifying CrossAttention so that it works better on the Apple Neural Engine #691
MatthewWaller wants to merge 8 commits into huggingface:main
Conversation
Update repo with main
The documentation is not available anymore as the PR was closed or merged.
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
return hidden_states

batch_size, sequence_length, heads, last_dim = query.shape
attn = torch.einsum("bjhd,bihd->bhji", query, key)
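For reference, a minimal sketch of what the quoted einsum computes when query and key keep the (batch, seq, heads, head_dim) layout; the `scale` argument here is an assumption standing in for the module's usual 1/sqrt(head_dim) factor:

```python
import torch

def attention_scores(query, key, scale):
    # query: (batch, seq_q, heads, head_dim)
    # key:   (batch, seq_k, heads, head_dim)
    # "bjhd,bihd->bhji" contracts over head_dim, yielding per-head scores
    # of shape (batch, heads, seq_q, seq_k) without first folding the
    # heads into the batch dimension (no reshape_heads_to_batch_dim).
    return torch.einsum("bjhd,bihd->bhji", query, key) * scale
```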
I think we moved away from einsum for speed and ONNX-compatibility cc @NouamaneTazi @anton-l no?
In my experiments, einsum is equivalent to matmul in terms of speed, and both support jitting.
I believe we moved away from it because of some MPS compatibility issues. cc @pcuenca
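As a rough way to check that claim, here is a small CPU timing sketch (shapes are illustrative, not taken from the model) comparing the einsum against an equivalent transpose-plus-matmul:

```python
import time
import torch

b, s, h, d = 2, 1024, 8, 40
q = torch.randn(b, s, h, d)
k = torch.randn(b, s, h, d)

def timeit(fn, iters=20):
    fn()  # warm-up run
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    return (time.perf_counter() - start) / iters

# einsum path: contract over head_dim directly
einsum_t = timeit(lambda: torch.einsum("bjhd,bihd->bhji", q, k))
# matmul path: move heads next to batch, then do a batched matmul
matmul_t = timeit(lambda: torch.matmul(q.permute(0, 2, 1, 3), k.permute(0, 2, 3, 1)))
print(f"einsum {einsum_t * 1e3:.2f} ms vs matmul {matmul_t * 1e3:.2f} ms")
```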
Yeah, this is going to take more investigation. More experimenting has revealed that this may not be the exact pain point for the ANE. I know that einsum can cause problems in certain cases; only two einsum variants were natively supported by coremltools, for instance, but this one should work without a problem. Since I haven't been able to fully diagnose where the hang-up is, I'll put this PR on ice.
Hi folks,
This is to address this issue.
I converted this CrossAttention portion with coremltools, and it does in fact remove about 4 reshape operations and a few transposes, getting down to 4 transposes and 4 reshapes remaining.
Unfortunately, it seems that is still too many to compile on the ANE.
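For anyone who wants to reproduce that op count, here is a rough sketch of how the score einsum alone could be converted and its emitted ops inspected; the wrapper module, the shapes, and the `milinternal` inspection path are assumptions and may vary with the coremltools version:

```python
import torch
import coremltools as ct

# Hypothetical standalone wrapper around just the score einsum, so the
# converter's output for this piece can be inspected in isolation.
class ScoreBlock(torch.nn.Module):
    def forward(self, q, k):
        return torch.einsum("bjhd,bihd->bhji", q, k)

example = (torch.randn(1, 64, 8, 40), torch.randn(1, 64, 8, 40))
traced = torch.jit.trace(ScoreBlock().eval(), example)

# convert_to="milinternal" returns the intermediate MIL program, whose
# printout lists the reshape / transpose / matmul ops the converter produced.
prog = ct.convert(
    traced,
    inputs=[ct.TensorType(name="q", shape=example[0].shape),
            ct.TensorType(name="k", shape=example[1].shape)],
    convert_to="milinternal",
)
print(prog)
```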
Any ideas about what else I could do to simplify this? I took a stab at using another einsum for the attn and value matmul, but I don't think I was doing it correctly.
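One possible way to write that second product as an einsum, given the (batch, heads, seq_q, seq_k) score layout above; this is an untested sketch, not something verified on the PR branch or on the ANE:

```python
import torch

def attention_output(attn, value):
    # attn:  (batch, heads, seq_q, seq_k) -- softmaxed scores from the
    #        "bjhd,bihd->bhji" einsum (j = query index, i = key index)
    # value: (batch, seq_k, heads, head_dim)
    # Contracting over the key index i keeps the output in the
    # (batch, seq_q, heads, head_dim) layout, so no extra transpose is
    # needed before the final merge of heads and head_dim.
    return torch.einsum("bhji,bihd->bjhd", attn, value)
```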