
Begin simplifying CrossAttention so that it works better on the Apple Neural Engine #691

Closed
MatthewWaller wants to merge 8 commits into huggingface:main from MatthewWaller:main

Conversation

@MatthewWaller

Hi folks,

This is to address this issue.

I converted this CrossAttention portion with coremltools, and it does in fact remove about 4 reshape operations and a few transposes, getting down to 4 transposes and 4 reshapes.

Unfortunately, it seems that is still too many to compile on the ANE.

Any ideas about what else I could do to simplify this? I took a stab at using another einsum for the attn and value matmul, but I don't think I was doing it correctly.
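For context, here is a minimal sketch of the fully einsum-based attention this PR is working toward. The second einsum (for the attn/value matmul mentioned above) and the final reshape are assumptions about the intended pattern, not code from the diff:

```python
import torch

def einsum_cross_attention(query, key, value, scale):
    # query, key, value: (batch, seq, heads, head_dim), kept in this layout so the
    # usual reshape_heads_to_batch_dim / transpose round-trips are avoided.
    attn = torch.einsum("bjhd,bihd->bhji", query, key) * scale  # (batch, heads, q_len, k_len)
    attn = attn.softmax(dim=-1)
    # A second einsum can stand in for the attn @ value matmul and lands directly
    # back in (batch, seq, heads, head_dim) without an explicit transpose.
    out = torch.einsum("bhji,bihd->bjhd", attn, value)
    batch, seq, heads, head_dim = out.shape
    return out.reshape(batch, seq, heads * head_dim)
```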

@MatthewWaller
Author

cc: @patrickvonplaten @pcuenca

@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented Sep 30, 2022

The documentation is not available anymore as the PR was closed or merged.

hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
return hidden_states
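# Diff context: query and key stay in (batch, seq, heads, head_dim) layout here,
# so the attention scores come from a single einsum rather than reshape + bmm.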
batch_size, sequence_length, heads, last_dim = query.shape
attn = torch.einsum("bjhd,bihd->bhji", query, key)
Contributor

I think we moved away from einsum for speed and ONNX-compatibility cc @NouamaneTazi @anton-l no?

Member

In my experiments, einsum is equivalent to matmul in terms of speed, and both support jitting.
I believe we moved away from it because of some MPS compatibility issues. cc @pcuenca
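If anyone wants to sanity-check the speed claim locally, here is a rough timing sketch; the shapes are illustrative, not taken from the model:

```python
import torch
import torch.utils.benchmark as benchmark

b, n, h, d = 2, 1024, 8, 40
q = torch.randn(b, n, h, d)
k = torch.randn(b, n, h, d)

t_einsum = benchmark.Timer(
    stmt='torch.einsum("bjhd,bihd->bhji", q, k)',
    globals={"torch": torch, "q": q, "k": k},
)
# Equivalent matmul path: move heads next to batch, then a batched matmul.
t_matmul = benchmark.Timer(
    stmt="q.permute(0, 2, 1, 3) @ k.permute(0, 2, 3, 1)",
    globals={"q": q, "k": k},
)
print(t_einsum.timeit(50))
print(t_matmul.timeit(50))
```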

@MatthewWaller
Copy link
Author

Yeah, this is going to take more investigation. More experimenting has revealed that this may not be the exact pain point for the ANE.

I know that einsum can cause problems for certain equation patterns. Only two variants were natively supported by coremltools, for instance. This one is among those that should work without a problem.
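For anyone trying to reproduce the diagnosis, one way to isolate the op is to convert just the einsum in a toy module and see whether Core ML will schedule it. This is a sketch assuming a recent coremltools; the module and shapes are made up for illustration:

```python
import coremltools as ct
import torch

class EinsumQK(torch.nn.Module):
    def forward(self, query, key):
        return torch.einsum("bjhd,bihd->bhji", query, key)

example = (torch.randn(2, 64, 8, 40), torch.randn(2, 64, 8, 40))
traced = torch.jit.trace(EinsumQK().eval(), example)

mlmodel = ct.convert(
    traced,
    inputs=[ct.TensorType(name="query", shape=example[0].shape),
            ct.TensorType(name="key", shape=example[1].shape)],
    compute_units=ct.ComputeUnit.ALL,  # let Core ML place ops on the ANE if it can
)
```

If the equation isn't one of the natively supported patterns, conversion will fail or fall back to a decomposition, which helps narrow down where the hangup is.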

But since I haven't been able to fully diagnose where the hangup is, I'll put this PR on ice.

PhaneeshB pushed a commit to nod-ai/diffusers that referenced this pull request Mar 1, 2023