Indexes of nearest neighbor #13

@alex-kharlamov

This code

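import torch
import numpy as np  # only needed for the commented-out permutation variant below
# N3AggregationBase is assumed to be imported from this repository's non-local module

N = 7
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
nn = N3AggregationBase(5, temp_opt={"external_temp": False})
nn.to(device)  # keep the module on the same device as the inputs (also works on CPU-only machines)
nn.nnn.log_temp_bias = -50 # decrease temperature -> NNN acts more like hard kNN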

# x = torch.tensor(np.random.permutation(list(range(N))), dtype=torch.float, requires_grad=True)
x = torch.tensor(list(range(N)), dtype=torch.float, requires_grad=True)
n = torch.zeros_like(x).normal_() * 0.0001
x = x+n
x = x.reshape(1, N, 1).to(device)
xe = x
ye = xe
I = torch.tensor(list(range(N)), dtype=torch.long).repeat(N, 1).reshape(1, N, N).to(device)

z = nn(x, xe, ye, I)

produces an output z with shape torch.Size([1, 7, 1, 5]).
If the input embeddings have more than one feature, the third dimension of the output changes accordingly.

How can I get the global indexes of the nearest neighbors?
For example, if XE has shape [1, 10, 5] and YE has shape [1, 3, 5], I would like output indexes with shape [1, 3], like the nearest-neighbor indexes returned by KNeighborsClassifier.
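
To make the requested output concrete, the following is a minimal sketch of a hard nearest-neighbor lookup in plain PyTorch. It is not the N3 block's API. It assumes XE holds the 10 database items, YE holds the 3 queries, and that Euclidean distance is the intended metric. The helper name `nearest_neighbor_indices` is hypothetical.

```python
import torch


def nearest_neighbor_indices(xe: torch.Tensor, ye: torch.Tensor) -> torch.Tensor:
    """For each query row in ye, return the index of the closest row in xe.

    xe: [B, N, E] database embeddings (assumed role)
    ye: [B, M, E] query embeddings (assumed role)
    returns: [B, M] long tensor of indices into the N database items
    """
    dists = torch.cdist(ye, xe)   # pairwise Euclidean distances, shape [B, M, N]
    return dists.argmin(dim=-1)   # closest database item per query, shape [B, M]


torch.manual_seed(0)
xe = torch.randn(1, 10, 5)
# Build queries as slightly perturbed copies of database rows 7, 2 and 4,
# so the expected nearest neighbors are known in advance.
ye = xe[:, [7, 2, 4], :] + 1e-4 * torch.randn(1, 3, 5)

idx = nearest_neighbor_indices(xe, ye)
print(idx.shape)  # torch.Size([1, 3])
print(idx)        # tensor([[7, 2, 4]])
```

For more than one neighbor per query, `dists.topk(k, dim=-1, largest=False).indices` would give a `[B, M, k]` index tensor instead. Note that `argmin` and `topk` indices are not differentiable, unlike the soft weights the N3 block produces.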
