轩阳的赛博居所

Sing once again with me!

0%

Frame Averaging 论文阅读

文献阅读 等变图神经网络 Frame Averaging

ICLR 2022 ORAL


目录:

  1. 概念说明和理论证明
  2. 实验细节设置、图的背景知识和代码分析

补充知识:

  1. WL算法判断图同构
  2. GIN网络
  3. message-passing-GNN

参考资料:

【GNN】WL-test:GNN 的性能上界 (qq.com)

KNN算法(k近邻算法)原理及总结-CSDN博客

搞懂DGCNN,这篇就够了!论文及代码完全解析 - 知乎 (zhihu.com)


相关文献:

SE(3)-Transformers: 3D Roto-Translation Equivariant Attention Networks

Vector Neurons: A General Framework for SO(3)-Equivariant Networks(矢量神经元?)

On the Universality of Rotation Equivariant Point Cloud Networks ICLR2020

GemNet: Universal Directional Graph Neural Networks for Molecules NIPS2021

Equiformer: Equivariant Graph Attention Transformer for 3D Atomistic Graphs ICLR2022


  • 等变性:称函数在群空间G上具有等变性,若

    其中,是输入和输出空间的群表示。

  • 群表示:把一个群通过群同态,映射到一个可逆线性变换(可逆矩阵)中。

  • 赋范线性空间:定义了范数的线性空间。赋范线性空间的性质比距离空间强很多,因为它上面一定可以定义距离

映射。考虑一个群,通过群表示作用于V,W上。

目标:让成为“不变的”函数,即;让成为等变的函数,即。这里之所以一个是映射到R,一个是映射到W,是因为前者尝尝对应于分类问题,后者则经常对应于点云处理等,需要输出以(和输入)相同方式旋转的问题。

之前的研究说明,通过数据在整个G上进行变换,然后取

就可以得到不变或者等变的网络。然而,FA只用了G的一个子集。

  • Def1. 一个frame被定义为一个取值为集合的函数,将输入空间V映射到G的子集的集合,也就是每个输入空间的元素x映到一个G的子集。其具有两个性质(注意,下面的F(X)是一个集合,里面的每个元素都是G中的元素)
    • G-equivariant 。其中gF(X)就是g对F(x)中每个元素进行左乘后得到的元素集合,上式等号指代集合意义的相等。
    • 有界性 通过X定义的frame,通过frame里的元素g定义的ρ(g),ρ(g)是有界的

因此可以用这个G的子集,来定义我们的平均化操作

  • Theo1.这样得到的frame是具有G-不变/等变的。也就是只用这个F(x)进行的映射构造,和利用整个G进行构造,效果相同。

不变性的证明:
$$
\text{prove of: }<\phi>F(\rho_1(g’)X) = <\phi>F(X)\
\text{ 即证明利用帧平均得到的映射,对于G中每一个元素都是不变的。}
\forall g’\in G, \
\begin{aligned}<\phi>F(\rho_1(g’)X)=&\frac{1}{F(X)}\sum{g\in g’F(X)}\phi(\rho_1(g)^{-1}\rho_1(g’)X) \quad &\text{(定义)}\
=&\frac{1}{F(X)}\sum
{g\in F(X)}\phi(\rho_1(g’g)^{-1}\rho_1(g’)X) \quad &\
=&\frac{1}{F(X)}\sum
{g\in F(X)}\phi(\rho_1(g)^{-1}X)\quad &\text{(群同态保持乘法)}\
=&<\phi>_F(X)&
\end{aligned}
$$
第二个等号是因为,F(X)中每个元素是g,让其中的每个元素g左乘g’,就相当于一个g’F(X)的元素。第三个等号是因为群同态保持乘法,而g’g的逆是g逆g’逆,因为逆元唯一。

等变性的证明是类似的:

不变映射,是等变映射在选择 W = R 和平凡表示 ρ2(g) ≡ 1 时的特例。

进一步地,如果某个网络架构已经对于某个对称群H具有了不变性或等变性,则可以将这种不变性或等变性扩展到更大的对称群H×G。

预定义:设是G和H在V上的表示,同理是GH在W上的表示。

  1. 称表示是可交换的,如果对于所有的g,h,X(属于V)有。如果他们可交换,则可以定义GH上的一个表示
  2. F(X)对于H来说是不变的,即
  • Theo2. 设F是H-不变,G-等变的:
    • 如果是H-不变的,表示是可交换的,那么是G×H不变的
    • 如果是H-等变的,表示是可交换的,那么是G×H等变的

也可以通过右乘来定义等变性/不变性。

更加高效地计算不变性frame:

稳定子群:作用于X上而保持X不变的元素。

等价关系:如果中两个元素可以通过中的元素相互转换,,则它们是等价的。所有和g等价的F(X)中元素记为等价类,也就是一个轨道[g]。

  • Theo3. 是所有等大小轨道[g]的无交并,

说明这个F(X)大小至少是X处稳定子群的大小。另外由于[g]轨道上每个元素对于X作用后,通过会产生一样的效果:

因此只需要每个轨道挑一个出来计算/构造不变映射即可。(如何挑出不同的轨道?这个过程是否仍然需要计算?还是说依赖于具体的群的类型)

  • 推论:是个等变的,是均匀的,那么轨道也是均匀的

好像下面这个推论是对上面的问题的回答。既然每条轨道大小相等,那么我在所有轨道组成的并集也就是整个F(X)上随机取元素,取出的元素在每个轨道上的概率相等。

这个过程中唯一的问题就是需要计算出轨道的数量,我推测这里只要能算出frame的大小|F(X)|和稳定子群的大小就行了。可以用抽样方法估计稳定子群的大小,即随机找一些g看他们作用于X之后X是否保持不变。

运用了FA的映射,表达能力得到了保证。

  • Theo4. 一个运用了FA方法(假设其在K上是有界的,即K中的每个X,F(X)有界)的网络/映射(记为映射1),其与任意一个G-等变网络/映射(记为映射2)的结果的差距,可以被两个映射在上的上界一致控制。(映射1运用了FA之后也变成G-等变了)

    这个证明很有意思

    A.5 PROOF OF THEOREM 4

Let be an arbitrary equivariant function, a bounded equivariant frame

over a frame-finite domain Let be the constant from Definition 1. For arbitrary ,
$$
\begin{aligned}\left|\Psi(X)-\langle\Phi\rangle_{\mathcal{F}}(X)\right|{W}&=\left|\langle\Psi\rangle{\mathcal{F}}(X)-\langle\Phi\rangle_{\mathcal{F}}(X)\right|{W}\&\leq\frac{1}{|\mathcal{F}(X)|}\sum{g\in\mathcal{F}(X)}\left|\rho_{2}(g)\Psi(\rho_{1}(g)^{-1}X)-\rho_{2}(g)\Phi(\rho_{1}(g)^{-1}X)\right|{W}\&\leq\max{g\in\mathcal{F}(X)}\left|\rho_{2}(g)\right|{\mathrm{op}}\left|\Psi-\Phi\right|{K_{F},W}\&\leq c\left|\Psi-\Phi\right|{K{\mathcal{F},W}}\end{aligned}
$$

where in the first equality we used the fact that since is already G equivariant.

实验部分:

重点看了后两个实验的设置和代码

  • 3D点云-欧几里得变换群

群:rotation & reflection 群O(d) x translation 群T(d),或者SE(d)=SO(d)xT(d),其中SO(d)只包含旋转

Frame:通过PCA定义。

这里假设协方差矩阵具有simple spectrum,也就是非重复特征值,“几乎处处”可以定义F(X)

F也是Sn不变的。

  • 图(graph)-置换群Sn

这一块是最难看懂的,涉及对图的谱表示,WL可区分性及GIN对WL的性能逼近等

背景知识补充1:GIN网络和图同构问题的算法性能上界

对图同构问题(NP问题)目前最有效的算法是Weisfeiler-Lehman 算法,可以在准多项式时间求解。WL算法解决的问题是:比较两个图是否同构。概括来说,每次迭代中,将一个点表示成**{这个点的编号:周围的点的编号}**的形式(这里的“编号”相当于点的种类),然后为每一种表示构建唯一的hash编码。迭代多轮或者图收敛后,停止迭代。

WL算法和图网络类似:每轮都是聚合+更新。

WL-test算法是GNN的性能上界。因为可以证明:如果能用GNN区分两个图是否同构,则WL-test一定可以。证明要点在于WL-test算法的单射性质

构造出一个图网络,逼近这个上界。构造出的网络就是GIN。论文中定理指出了,图神经网络在满足这些条件时,能够将WL测试认为是非同构图的两个图和映射到不同的embedding

①按照进行每轮的aggregate;

②作用在图级的READOUT函数也是单射的

后面涉及了一些关于可数集函数映射和利用MLP逼近表示的定理,没有细看。总之最后得到的每层AGGREGATE更新方式是:

READOUT为对每次迭代得到的所有节点的特征求和得到该轮迭代的图特征,然后再拼接起每一轮迭代的图特征来得到最终的图特征

相比其他图网络的单层传播,多层感知机拟合能力强;相比mean和max,sum不会混淆一些结构。

代码说明:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
class GinNet(nn.Module):
def __init__(self):
super(GinNet, self).__init__()
neuron=64
r1=np.random.uniform()
r2=np.random.uniform()
r3=np.random.uniform()

nn1 = Sequential(Linear(dataset.num_features, neuron))
self.conv1 = GINConv(nn1,eps=r1,train_eps=True)

nn2 = Sequential(Linear(neuron, neuron))
self.conv2 = GINConv(nn2,eps=r2,train_eps=True)

nn3 = Sequential(Linear(neuron, neuron))
self.conv3 = GINConv(nn3,eps=r3,train_eps=True)

self.fc1 = torch.nn.Linear(neuron, 10)


def forward(self, data):

x=data.x
edge_index=data.edge_index

x = torch.tanh(self.conv1(x, edge_index))
x = torch.tanh(self.conv2(x, edge_index))
x = torch.tanh(self.conv3(x, edge_index))

x = global_add_pool(x, data.batch)
x = torch.tanh(self.fc1(x))
return x
1
2
3
4
5
6
7
8
9
10
# 索引数组的随机置换,相当于Sn里面元素的作用
def genrate_perm(perm_idx):
perm = [np.random.permutation(perm) for perm in perm_idx]
return perm

# 根据边索引构造一个n x n的邻接矩阵,然后将该矩阵展平输出
def generate_A(edge_index, n):
A=np.zeros((n,n),dtype=np.float32)
A[edge_index[0],edge_index[1]]=1
return torch.from_numpy(A.flatten())

sort_fn_laplacian的作用是根据拉普拉斯矩阵的特征值,返回节点排序索引和对应的特征向量

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
def sort_fn_laplacian(x,edge_index):
# 构建拉普拉斯矩阵
L_e, L_w = torch_geometric.utils.get_laplacian(edge_index)
L = np.zeros((x.shape[0],x.shape[0]),dtype=np.float32)
L[L_e[0],L_e[1]]=L_w

# 计算特征值和特征向量,按照特征值从大到小返回
evals, evecs = np.linalg.eigh(L)
# ----- create sorting criterion -----
# 使用unique函数找出所有唯一的特征值,以及每个特征值的起始索引和重数
unique_vals, evals_idx, evals_mult = np.unique(evals, return_counts=True, return_index=True) # get eigenvals multiplicity

chosen_evecs = []

# 如果特征值不重复,就取出特征向量的绝对值,加入chosen_evecs
# 如果特征值重复,计算子空间基向量的平方和的平方根,加入chosen_evecs
# 注意到,即使是重复特征值(重复了k次比如说),也会计算出k个向量加入chosen_evecs
for ii in range(len(evals_idx)):
if evals_mult[ii] == 1:
chosen_evecs.append(np.abs(evecs[:,evals_idx[ii]]))
else:
eigen_space_start_idx = evals_idx[ii]
eigen_space_size = evals_mult[ii]
eig_space_basis = evecs[:, eigen_space_start_idx:(eigen_space_start_idx+eigen_space_size)]
chosen_evecs.append(np.sqrt((eig_space_basis ** 2).sum(1)))

# 将选择的特征向量堆叠成一个矩阵,保留两位小数
chosen_evecs = np.stack(chosen_evecs, axis=1).round(decimals=2)
# lexsort根据键排序,返回排序索引sort_idx和chosen_evecs。[::-1]是倒序取所有的索引
sort_idx = np.lexsort([col for col in chosen_evecs.transpose()[::-1]]) # consider regular sort
return sort_idx, chosen_evecs

SortFrame 类是用于处理和转换图数据,按照基于拉普拉斯矩阵的特征向量对图中的节点进行排序。传入的data首先需要经过pre_transform,然后进行结点和边的重排序(使用sort_fn_laplacian),返回排序后的data。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
class SortFrame(object):
def __init__(self,device, pre_transform,sort_fn=sort_fn_laplacian):
self.pre_transform = pre_transform
self.sort_fn = sort_fn
self.device = device

def __call__(self, data):
data = self.pre_transform(data)
# data = Data(**data.__dict__)
sort_idx, to_sort = self.sort_fn(data.x, data.edge_index)
sorted_x = to_sort[sort_idx,:]
unique_rows, dup_rows_idx, dup_rows_mult = np.unique(sorted_x, axis=0, return_index=True, return_counts=True)

perm_start_idx = dup_rows_idx[dup_rows_mult!=1]
perm_size = dup_rows_mult[dup_rows_mult!=1]
perm_idx = []
for ii in range(len(perm_size)):
perm_idx.append(np.arange(perm_start_idx[ii], perm_start_idx[ii]+perm_size[ii]))
data.perm_idx = perm_idx
data.sort_idx = sort_idx
data.size = data.x.shape[0]
return data

SampleFrame8C类:生成一组新的图数据样本

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
class SampleFrame8C(object):
def __init__(self,size=64,sample_size=10, GA=False,MLP=False):
self.sample_size = sample_size
self.size = size
self.counter=0
self.GA=GA
self.id = not MLP
self.MLP = MLP

def apply_permutation(self, perm_idx, edge_index, x):
inv_perm_idx = np.zeros_like(perm_idx)
inv_perm_idx[perm_idx] = np.arange(perm_idx.shape[0]) #根据perm_idx.shape[0]返回等差索引
sorted_edge_index = torch.tensor(inv_perm_idx[edge_index])
sorted_x = x[perm_idx,:]
return sorted_edge_index, sorted_x

def permute_with_perm(self,perm_idx, perm, sort_idx, edge_index, x):
inv_sort_idx = np.zeros_like(sort_idx)
inv_sort_idx[sort_idx] = np.arange(sort_idx.shape[0])
inv_sort_idx[sort_idx[list(itertools.chain(*perm_idx))]] = list(itertools.chain(*perm))
sorted_edge_index = inv_sort_idx[edge_index]
cur_sort_idx = np.zeros_like(sort_idx)
cur_sort_idx[inv_sort_idx] = np.arange(sort_idx.shape[0])

sorted_x_feat = x[cur_sort_idx,:]

return sorted_edge_index, sorted_x_feat


def __call__(self,data):
n = data.x.shape[0]
m = self.size
x = data.x
d = x.shape[1]
edge_index = data.edge_index
sort_idx = data.sort_idx
perm_idx = data.perm_idx
x_arr = []
e_arr = []
if not perm_idx:
if self.GA:
sorted_edge_index, sorted_x = self.apply_permutation(np.random.permutation(n), edge_index.clone(), x.clone())
else:
sorted_edge_index, sorted_x = self.apply_permutation(sort_idx, edge_index.clone(), x.clone())
data.edge_index = sorted_edge_index.detach()
if self.MLP:
new_x = generate_A(sorted_edge_index,m).unsqueeze(0)
data = new_x,data.y
else:
if self.id:
data.x = torch.cat([sorted_x.detach(),torch.eye(n, dtype=x.dtype),torch.zeros((n,m-n), dtype=x.dtype)],1).clone()
else:
data.x = sorted_x.detach()
data.edge_index = sorted_edge_index
else:
for i in range(self.sample_size):
if self.GA:
sorted_edge_index, sorted_x = self.apply_permutation(np.random.permutation(n), edge_index.clone(), x.clone())
else:
perm = genrate_perm(perm_idx)
sorted_edge_index, sorted_x = self.permute_with_perm(perm_idx, perm, sort_idx, edge_index.clone(), x.clone())
sorted_edge_index = torch.from_numpy(sorted_edge_index)
if self.MLP:
x_arr.append(generate_A(sorted_edge_index,m))
else:
if self.id:
x_arr.append(torch.cat([sorted_x.detach(),torch.eye(n, dtype=x.dtype),torch.zeros((n,m-n), dtype=x.dtype)],1).clone())
else:
x_arr.append(sorted_x.detach())
e_arr.append(torch.tensor(sorted_edge_index.clone()) + (i*n))
if self.MLP:
data = torch.stack(x_arr,dim=0), data.y
else:
data.x = torch.cat(x_arr,0).detach()
data.edge_index = torch.cat(e_arr,1).detach()
return data
  • nbody

群:与3D点云的一致(利用结点特征计算协方差矩阵,得到仿射变换);同样可以证明这个群里的元素和置换群中的元素可交换,因此如果在一个Sn-invariant的backbone(映射)上运用F(X),得到的映射仍然是Sn-invariant的。

backbone:massage-passing-GNN,是Sn-invariant的,,构造backbone为

输入:3+3维,表示初始位置和初始速度。从附录来看更新结点的时候还要考虑二者相互排斥还是相互吸引。φe和φh分别是两个MLP。

代码解释:

FA_GNN网络架构:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
class FA_GNN(nn.Module):
def __init__(self, input_dim, hidden_nf, device='cpu', act_fn=nn.SiLU(), n_layers=4, attention=0, recurrent=False):
super(FA_GNN, self).__init__()
self.hidden_nf = hidden_nf
self.device = device
self.n_layers = n_layers
self.dimension_reduce = nn.ModuleList()
# GCL就是MLP:线性层+激活层。但是这个edge_in_nf是什么呢?
for i in range(0, n_layers):
self.add_module("gcl_%d" % i, GCL(self.hidden_nf, self.hidden_nf, self.hidden_nf, edges_in_nf=2, act_fn=act_fn, attention=attention, recurrent=recurrent))
self.decoder = nn.Sequential(nn.Linear(hidden_nf, hidden_nf),
act_fn,
nn.Linear(hidden_nf, 3))

self.embedding = nn.Sequential(nn.Linear(input_dim, hidden_nf))
self.to(self.device)

def forward(self, h, edges, edge_attr=None):
n_frame = 8
n_nodes = 5
batch_size = int(h.shape[0]/n_nodes)
n_h = h.shape[0]
edges = expand_edge(edges, n_h, n_frame)
edge_attr = edge_attr.repeat(n_frame,1)

# 原始的数据施加群作用,得到h,h再送入embedding得到嵌入
h, F_ops, center = create_frame(h, n_nodes)
h = self.embedding(h)

# 看起来目的是让每一层都等变/不变。每一层都先过models.i层,得到h,然后计算这一层施加群作用的h
for i in range(self.n_layers):
h, _ = self._modules["gcl_%d" % i](h, edges, edge_attr=edge_attr)
if i < (self.n_layers - 1):
# transform equiv features extraction
h = invert_latent_frame(h, F_ops, batch_size, n_nodes, None)
# compute new frame
h, F_ops, _ = create_latent_frame(h, n_nodes)


h = self.decoder(h)
h = invert_frame(h, F_ops, n_nodes, center)
return h

create_frame函数:输入是点的特征和点的数量,输出是施加了群作用的点特征(群作用指的是 协方差矩阵的特征值对应的八个旋转矩阵)

create_framecreate_latent_frame的主要区别是,create_frame是输入数据进入网络之前进行的处理,latent则是在网络过程中进行的处理,因此输入数据的维度存在一些区别。(bn6和bn3)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
def create_frame(nodes, n_nodes):
pnts = nodes[:,:3]
v = nodes[:,3:]
pnts = pnts.view(-1, n_nodes, 3).transpose(1,2)
v = v.view(-1, n_nodes, 3).transpose(1,2)
center = pnts.mean(2,True)
pnts_centered = pnts - center
# add noise
R = torch.bmm(pnts_centered,pnts_centered.transpose(1,2)) # 协方差矩阵
lambdas,V_ = torch.symeig(R.detach().cpu(),True)
F = V_.to(R) # 这行在干嘛?to可以设备转换,也可以类型转换,结合上面detach来看应该是device
ops = torch.tensor([[1,1,1],
[1,1,-1],
[1,-1,1],
[1,-1,-1],
[-1,1,1],
[-1,1,-1],
[-1,-1,1],
[-1,-1,-1]]).unsqueeze(1).to(F)
F_ops = ops.unsqueeze(0) * F.unsqueeze(1)

# 爱因斯坦求和:b*8*3*3.(交换最后两维) @ b*n_nodes*3.(交换最后两维)= b*8*n_nodes*3
# b*8个 3*3(转置后)中,每个3*3都以行向量为特征向量;
# b 个 3*n_nodes中,每列都是一个点的中心化坐标/速度
framed_input = torch.einsum('boij,bpj->bopi',F_ops.transpose(2,3),(pnts - center).transpose(1,2))
framed_v = torch.einsum('boij,bpj->bopi',F_ops.transpose(2,3),(v).transpose(1,2))

framed_input = framed_input.transpose(0,1)
framed_input = torch.reshape(framed_input,(-1,3))
framed_v = framed_v.transpose(0,1)
framed_v = torch.reshape(framed_v,(-1,3))
out = torch.cat([framed_input,framed_v],dim=1)
return out, F_ops.detach(), center.detach()

invert_frame函数:矩阵(群作用)作用于每一层的输出上

1
2
3
4
5
6
7
8
9
def invert_frame(pnts, F_ops, n_nodes, center):
pnts = pnts.view(8, -1, n_nodes,3)
pnts = pnts.transpose(0,1)
framed_input = torch.einsum('boij,bopj->bopi',F_ops, pnts)
framed_input = framed_input.mean(1)
if center is not None:
framed_input = framed_input + center.transpose(1,2)
framed_input = torch.reshape(framed_input,(-1,3))
return framed_input

invert_latent_frame函数:这个是在decoder之前用的。

1
2
3
4
5
6
7
8
9
10
def invert_latent_frame(pnts, F_ops, batch_size, n_nodes, center):
pnts = pnts.view(8, batch_size, n_nodes, -1,3)
pnts = pnts.transpose(0,1)
framed_input = torch.einsum('boij,bopfj->bopfi',F_ops, pnts)
framed_input = framed_input.mean(1)
if center is not None:
framed_input = framed_input + center.transpose(1,2).unsqueeze(-2)
framed_input = framed_input.contiguous()
framed_input = framed_input.view(batch_size,-1,3)
return framed_input