文献阅读 等变图神经网络 Frame Averaging ICLR 2022 ORAL
目录:
概念说明和理论证明
实验细节设置、图的背景知识和代码分析
补充知识:
WL算法判断图同构
GIN网络
message-passing-GNN
参考资料:
【GNN】WL-test:GNN 的性能上界 (qq.com)
KNN算法(k近邻算法)原理及总结-CSDN博客
搞懂DGCNN,这篇就够了!论文及代码完全解析 - 知乎 (zhihu.com)
相关文献:
SE(3)-Transformers: 3D Roto-Translation Equivariant Attention Networks
Vector Neurons: A General Framework for SO(3)-Equivariant Networks(矢量神经元?)
On the Universality of Rotation Equivariant Point Cloud Networks ICLR2020
GemNet: Universal Directional Graph Neural Networks for Molecules NIPS2021
Equiformer: Equivariant Graph Attention Transformer for 3D Atomistic Graphs ICLR2022
等变性:称函数 在群空间G上具有等变性,若 其中, 是输入和输出空间的群表示。
群表示:把一个群通过群同态,映射到一个可逆线性变换(可逆矩阵)中。
赋范线性空间:定义了范数的线性空间。赋范线性空间的性质比距离空间强很多,因为它上面一定可以定义距离
映射 。考虑一个群 ,通过群表示; 和; 作用于V,W上。
目标:让 成为“不变的”函数,即 ;让 成为等变的函数,即 。这里之所以一个是映射到R,一个是映射到W,是因为前者尝尝对应于分类问题,后者则经常对应于点云处理等,需要输出以(和输入)相同方式旋转的问题。
之前的研究说明,通过数据在整个G上进行变换,然后取
就可以得到不变或者等变的网络。然而,FA只用了G的一个子集。
Def1. 一个frame被定义为一个取值为集合的函数,将输入空间V映射到G的子集的集合,也就是每个输入空间的元素x映到一个G的子集。其具有两个性质(注意,下面的F(X)是一个集合,里面的每个元素都是G中的元素)
G-equivariant 。其中gF(X)就是g对F(x)中每个元素进行左乘后得到的元素集合,上式等号指代集合意义的相等。
有界性 通过X定义的frame,通过frame里的元素g定义的ρ(g),ρ(g)是有界的
因此可以用这个G的子集,来定义我们的平均化操作
Theo1.这样得到的frame是具有G-不变/等变的。也就是只用这个F(x)进行的映射构造,和利用整个G进行构造,效果相同。
不变性的证明: $$ \text{prove of: }<\phi>F(\rho_1(g’)X) = <\phi>F(X)\ \text{ 即证明利用帧平均得到的映射,对于G中每一个元素都是不变的。} \forall g’\in G, \ \begin{aligned}<\phi>F(\rho_1(g’)X)=&\frac{1}{F(X)}\sum {g\in g’F(X)}\phi(\rho_1(g)^{-1}\rho_1(g’)X) \quad &\text{(定义)}\ =&\frac{1}{F(X)}\sum {g\in F(X)}\phi(\rho_1(g’g)^{-1}\rho_1(g’)X) \quad &\ =&\frac{1}{F(X)}\sum {g\in F(X)}\phi(\rho_1(g)^{-1}X)\quad &\text{(群同态保持乘法)}\ =&<\phi>_F(X)& \end{aligned} $$ 第二个等号是因为,F(X)中每个元素是g,让其中的每个元素g左乘g’,就相当于一个g’F(X)的元素。第三个等号是因为群同态保持乘法,而g’g的逆是g逆g’逆,因为逆元唯一。
等变性的证明是类似的:
不变映射,是等变映射在选择 W = R 和平凡表示 ρ2(g) ≡ 1 时的特例。
进一步地,如果某个网络架构已经对于某个对称群H具有了不变性或等变性,则可以将这种不变性或等变性扩展到更大的对称群H×G。
预定义:设 是G和H在V上的表示, 同理是GH在W上的表示。
称表示 是可交换的,如果对于所有的g,h,X(属于V)有 。如果他们可交换,则可以定义GH上的一个表示
F(X)对于H来说是不变的,即
Theo2. 设F是H-不变,G-等变的:
如果 是H-不变的,表示 是可交换的,那么 是G×H不变的
如果 是H-等变的,表示 是可交换的,那么 是G×H等变的
也可以通过右乘来定义等变性/不变性。
更加高效地计算不变性frame:
稳定子群:作用于X上而保持X不变的元素。
等价关系:如果 中两个元素可以通过 中的元素相互转换, ,则它们是等价的。所有和g等价的F(X)中元素记为等价类,也就是一个轨道[g]。
说明这个F(X)大小至少是X处稳定子群的大小。另外由于[g]轨道上每个元素对于X作用后,通过 会产生一样的效果:
因此只需要每个轨道挑一个出来计算/构造不变映射即可。(如何挑出不同的轨道?这个过程是否仍然需要计算?还是说依赖于具体的群的类型)
好像下面这个推论是对上面的问题的回答。既然每条轨道大小相等,那么我在所有轨道组成的并集也就是整个F(X)上随机取元素,取出的元素在每个轨道上的概率相等。
这个过程中唯一的问题就是需要计算出轨道的数量,我推测这里只要能算出frame的大小|F(X)|和稳定子群 的大小就行了。可以用抽样方法估计稳定子群的大小,即随机找一些g看他们作用于X之后X是否保持不变。
运用了FA的映射,表达能力得到了保证。
Let be an arbitrary equivariant function, a bounded equivariant frame
over a frame-finite domain Let be the constant from Definition 1. For arbitrary , $$ \begin{aligned}\left|\Psi(X)-\langle\Phi\rangle_{\mathcal{F}}(X)\right|{W}&=\left|\langle\Psi\rangle {\mathcal{F}}(X)-\langle\Phi\rangle_{\mathcal{F}}(X)\right|{W}\&\leq\frac{1}{|\mathcal{F}(X)|}\sum {g\in\mathcal{F}(X)}\left|\rho_{2}(g)\Psi(\rho_{1}(g)^{-1}X)-\rho_{2}(g)\Phi(\rho_{1}(g)^{-1}X)\right|{W}\&\leq\max {g\in\mathcal{F}(X)}\left|\rho_{2}(g)\right|{\mathrm{op}}\left|\Psi-\Phi\right| {K_{F},W}\&\leq c\left|\Psi-\Phi\right|{K {\mathcal{F},W}}\end{aligned} $$
where in the first equality we used the fact that since is already G equivariant.
实验部分: 重点看了后两个实验的设置和代码
群:rotation & reflection 群O(d) x translation 群T(d),或者SE(d)=SO(d)xT(d),其中SO(d)只包含旋转
Frame:通过PCA定义。的 行 ( 第 一 维 ) 是 数 据 量 , 第 二 维 是 特 征 维 度 是 一 个 列 向 量 , 每 个 维 度 记 录 中 所 有 数 据 在 这 个 维 度 的 均 值 协 方 差 矩 阵 有 特 征 值 , 对 应 特 征 向 量 或 这里假设协方差矩阵具有simple spectrum,也就是非重复特征值,“几乎处处”可以定义F(X)
F也是Sn不变的。
这一块是最难看懂的,涉及对图的谱表示,WL可区分性及GIN对WL的性能逼近等
背景知识补充1:GIN网络和图同构问题的算法性能上界
对图同构问题(NP问题)目前最有效的算法是Weisfeiler-Lehman 算法,可以在准多项式时间求解。WL算法解决的问题是:比较两个图是否同构。概括来说,每次迭代中,将一个点表示成**{这个点的编号:周围的点的编号}**的形式(这里的“编号”相当于点的种类),然后为每一种表示构建唯一的hash编码。迭代多轮或者图收敛后,停止迭代。
WL算法和图网络类似:每轮都是聚合+更新。
WL-test算法是GNN的性能上界。因为可以证明:如果能用GNN区分两个图是否同构,则WL-test一定可以。证明要点在于WL-test算法的单射性质 。
构造出一个图网络,逼近这个上界。构造出的网络就是GIN。论文中定理指出了,图神经网络在满足这些条件时,能够将WL测试认为是非同构图的两个图和映射到不同的embedding
①按照 进行每轮的aggregate;
②作用在图级的READOUT函数也是单射的
后面涉及了一些关于可数集函数映射和利用MLP逼近表示的定理,没有细看。总之最后得到的每层AGGREGATE 更新方式是:READOUT 为对每次迭代得到的所有节点的特征求和得到该轮迭代的图特征,然后再拼接起每一轮迭代的图特征来得到最终的图特征 相比其他图网络的单层传播,多层感知机拟合能力强;相比mean和max,sum不会混淆一些结构。
代码说明:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 class GinNet (nn.Module): def __init__ (self ): super (GinNet, self ).__init__() neuron=64 r1=np.random.uniform() r2=np.random.uniform() r3=np.random.uniform() nn1 = Sequential(Linear(dataset.num_features, neuron)) self .conv1 = GINConv(nn1,eps=r1,train_eps=True ) nn2 = Sequential(Linear(neuron, neuron)) self .conv2 = GINConv(nn2,eps=r2,train_eps=True ) nn3 = Sequential(Linear(neuron, neuron)) self .conv3 = GINConv(nn3,eps=r3,train_eps=True ) self .fc1 = torch.nn.Linear(neuron, 10 ) def forward (self, data ): x=data.x edge_index=data.edge_index x = torch.tanh(self .conv1(x, edge_index)) x = torch.tanh(self .conv2(x, edge_index)) x = torch.tanh(self .conv3(x, edge_index)) x = global_add_pool(x, data.batch) x = torch.tanh(self .fc1(x)) return x
1 2 3 4 5 6 7 8 9 10 def genrate_perm (perm_idx ): perm = [np.random.permutation(perm) for perm in perm_idx] return perm def generate_A (edge_index, n ): A=np.zeros((n,n),dtype=np.float32) A[edge_index[0 ],edge_index[1 ]]=1 return torch.from_numpy(A.flatten())
sort_fn_laplacian
的作用是根据拉普拉斯矩阵的特征值,返回节点排序索引和对应的特征向量
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 def sort_fn_laplacian (x,edge_index ): L_e, L_w = torch_geometric.utils.get_laplacian(edge_index) L = np.zeros((x.shape[0 ],x.shape[0 ]),dtype=np.float32) L[L_e[0 ],L_e[1 ]]=L_w evals, evecs = np.linalg.eigh(L) unique_vals, evals_idx, evals_mult = np.unique(evals, return_counts=True , return_index=True ) chosen_evecs = [] for ii in range (len (evals_idx)): if evals_mult[ii] == 1 : chosen_evecs.append(np.abs (evecs[:,evals_idx[ii]])) else : eigen_space_start_idx = evals_idx[ii] eigen_space_size = evals_mult[ii] eig_space_basis = evecs[:, eigen_space_start_idx:(eigen_space_start_idx+eigen_space_size)] chosen_evecs.append(np.sqrt((eig_space_basis ** 2 ).sum (1 ))) chosen_evecs = np.stack(chosen_evecs, axis=1 ).round (decimals=2 ) sort_idx = np.lexsort([col for col in chosen_evecs.transpose()[::-1 ]]) return sort_idx, chosen_evecs
SortFrame
类是用于处理和转换图数据,按照基于拉普拉斯矩阵的特征向量对图中的节点进行排序。传入的data首先需要经过pre_transform,然后进行结点和边的重排序(使用sort_fn_laplacian),返回排序后的data。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 class SortFrame (object ): def __init__ (self,device, pre_transform,sort_fn=sort_fn_laplacian ): self .pre_transform = pre_transform self .sort_fn = sort_fn self .device = device def __call__ (self, data ): data = self .pre_transform(data) sort_idx, to_sort = self .sort_fn(data.x, data.edge_index) sorted_x = to_sort[sort_idx,:] unique_rows, dup_rows_idx, dup_rows_mult = np.unique(sorted_x, axis=0 , return_index=True , return_counts=True ) perm_start_idx = dup_rows_idx[dup_rows_mult!=1 ] perm_size = dup_rows_mult[dup_rows_mult!=1 ] perm_idx = [] for ii in range (len (perm_size)): perm_idx.append(np.arange(perm_start_idx[ii], perm_start_idx[ii]+perm_size[ii])) data.perm_idx = perm_idx data.sort_idx = sort_idx data.size = data.x.shape[0 ] return data
SampleFrame8C
类:生成一组新的图数据样本
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 class SampleFrame8C (object ): def __init__ (self,size=64 ,sample_size=10 , GA=False ,MLP=False ): self .sample_size = sample_size self .size = size self .counter=0 self .GA=GA self .id = not MLP self .MLP = MLP def apply_permutation (self, perm_idx, edge_index, x ): inv_perm_idx = np.zeros_like(perm_idx) inv_perm_idx[perm_idx] = np.arange(perm_idx.shape[0 ]) sorted_edge_index = torch.tensor(inv_perm_idx[edge_index]) sorted_x = x[perm_idx,:] return sorted_edge_index, sorted_x def permute_with_perm (self,perm_idx, perm, sort_idx, edge_index, x ): inv_sort_idx = np.zeros_like(sort_idx) inv_sort_idx[sort_idx] = np.arange(sort_idx.shape[0 ]) inv_sort_idx[sort_idx[list (itertools.chain(*perm_idx))]] = list (itertools.chain(*perm)) sorted_edge_index = inv_sort_idx[edge_index] cur_sort_idx = np.zeros_like(sort_idx) cur_sort_idx[inv_sort_idx] = np.arange(sort_idx.shape[0 ]) sorted_x_feat = x[cur_sort_idx,:] return sorted_edge_index, sorted_x_feat def __call__ (self,data ): n = data.x.shape[0 ] m = self .size x = data.x d = x.shape[1 ] edge_index = data.edge_index sort_idx = data.sort_idx perm_idx = data.perm_idx x_arr = [] e_arr = [] if not perm_idx: if self .GA: sorted_edge_index, sorted_x = self .apply_permutation(np.random.permutation(n), edge_index.clone(), x.clone()) else : sorted_edge_index, sorted_x = self .apply_permutation(sort_idx, edge_index.clone(), x.clone()) data.edge_index = sorted_edge_index.detach() if self .MLP: new_x = generate_A(sorted_edge_index,m).unsqueeze(0 ) data = new_x,data.y else : if self .id : data.x = torch.cat([sorted_x.detach(),torch.eye(n, dtype=x.dtype),torch.zeros((n,m-n), dtype=x.dtype)],1 ).clone() else : data.x = sorted_x.detach() data.edge_index = sorted_edge_index else : for i in range (self .sample_size): if self .GA: sorted_edge_index, sorted_x = self .apply_permutation(np.random.permutation(n), edge_index.clone(), x.clone()) else : perm = genrate_perm(perm_idx) sorted_edge_index, sorted_x = self .permute_with_perm(perm_idx, perm, sort_idx, edge_index.clone(), x.clone()) sorted_edge_index = torch.from_numpy(sorted_edge_index) if self .MLP: x_arr.append(generate_A(sorted_edge_index,m)) else : if self .id : x_arr.append(torch.cat([sorted_x.detach(),torch.eye(n, dtype=x.dtype),torch.zeros((n,m-n), dtype=x.dtype)],1 ).clone()) else : x_arr.append(sorted_x.detach()) e_arr.append(torch.tensor(sorted_edge_index.clone()) + (i*n)) if self .MLP: data = torch.stack(x_arr,dim=0 ), data.y else : data.x = torch.cat(x_arr,0 ).detach() data.edge_index = torch.cat(e_arr,1 ).detach() return data
群:与3D点云的一致(利用结点特征计算协方差矩阵,得到仿射变换);同样可以证明这个群里的元素和置换群中的元素可交换,因此如果在一个Sn-invariant的backbone(映射)上运用F(X),得到的映射仍然是Sn-invariant的。
backbone:massage-passing-GNN,是Sn-invariant的, ,构造backbone为 。
输入:3+3维,表示初始位置和初始速度。从附录来看更新结点的时候还要考虑二者相互排斥还是相互吸引。φe和φh分别是两个MLP。
代码解释:
FA_GNN
网络架构:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 class FA_GNN (nn.Module): def __init__ (self, input_dim, hidden_nf, device='cpu' , act_fn=nn.SiLU( ), n_layers=4 , attention=0 , recurrent=False ): super (FA_GNN, self ).__init__() self .hidden_nf = hidden_nf self .device = device self .n_layers = n_layers self .dimension_reduce = nn.ModuleList() for i in range (0 , n_layers): self .add_module("gcl_%d" % i, GCL(self .hidden_nf, self .hidden_nf, self .hidden_nf, edges_in_nf=2 , act_fn=act_fn, attention=attention, recurrent=recurrent)) self .decoder = nn.Sequential(nn.Linear(hidden_nf, hidden_nf), act_fn, nn.Linear(hidden_nf, 3 )) self .embedding = nn.Sequential(nn.Linear(input_dim, hidden_nf)) self .to(self .device) def forward (self, h, edges, edge_attr=None ): n_frame = 8 n_nodes = 5 batch_size = int (h.shape[0 ]/n_nodes) n_h = h.shape[0 ] edges = expand_edge(edges, n_h, n_frame) edge_attr = edge_attr.repeat(n_frame,1 ) h, F_ops, center = create_frame(h, n_nodes) h = self .embedding(h) for i in range (self .n_layers): h, _ = self ._modules["gcl_%d" % i](h, edges, edge_attr=edge_attr) if i < (self .n_layers - 1 ): h = invert_latent_frame(h, F_ops, batch_size, n_nodes, None ) h, F_ops, _ = create_latent_frame(h, n_nodes) h = self .decoder(h) h = invert_frame(h, F_ops, n_nodes, center) return h
create_frame
函数:输入是点的特征和点的数量,输出是施加了群作用的点特征(群作用指的是 协方差矩阵的特征值对应的八个旋转矩阵)
create_frame
和create_latent_frame
的主要区别是,create_frame是输入数据进入网络之前进行的处理,latent则是在网络过程中进行的处理,因此输入数据的维度存在一些区别。(bn 6和bn 3)
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 def create_frame (nodes, n_nodes ): pnts = nodes[:,:3 ] v = nodes[:,3 :] pnts = pnts.view(-1 , n_nodes, 3 ).transpose(1 ,2 ) v = v.view(-1 , n_nodes, 3 ).transpose(1 ,2 ) center = pnts.mean(2 ,True ) pnts_centered = pnts - center R = torch.bmm(pnts_centered,pnts_centered.transpose(1 ,2 )) lambdas,V_ = torch.symeig(R.detach().cpu(),True ) F = V_.to(R) ops = torch.tensor([[1 ,1 ,1 ], [1 ,1 ,-1 ], [1 ,-1 ,1 ], [1 ,-1 ,-1 ], [-1 ,1 ,1 ], [-1 ,1 ,-1 ], [-1 ,-1 ,1 ], [-1 ,-1 ,-1 ]]).unsqueeze(1 ).to(F) F_ops = ops.unsqueeze(0 ) * F.unsqueeze(1 ) framed_input = torch.einsum('boij,bpj->bopi' ,F_ops.transpose(2 ,3 ),(pnts - center).transpose(1 ,2 )) framed_v = torch.einsum('boij,bpj->bopi' ,F_ops.transpose(2 ,3 ),(v).transpose(1 ,2 )) framed_input = framed_input.transpose(0 ,1 ) framed_input = torch.reshape(framed_input,(-1 ,3 )) framed_v = framed_v.transpose(0 ,1 ) framed_v = torch.reshape(framed_v,(-1 ,3 )) out = torch.cat([framed_input,framed_v],dim=1 ) return out, F_ops.detach(), center.detach()
invert_frame
函数:矩阵(群作用)作用于每一层的输出上
1 2 3 4 5 6 7 8 9 def invert_frame (pnts, F_ops, n_nodes, center ): pnts = pnts.view(8 , -1 , n_nodes,3 ) pnts = pnts.transpose(0 ,1 ) framed_input = torch.einsum('boij,bopj->bopi' ,F_ops, pnts) framed_input = framed_input.mean(1 ) if center is not None : framed_input = framed_input + center.transpose(1 ,2 ) framed_input = torch.reshape(framed_input,(-1 ,3 )) return framed_input
invert_latent_frame
函数:这个是在decoder之前用的。
1 2 3 4 5 6 7 8 9 10 def invert_latent_frame (pnts, F_ops, batch_size, n_nodes, center ): pnts = pnts.view(8 , batch_size, n_nodes, -1 ,3 ) pnts = pnts.transpose(0 ,1 ) framed_input = torch.einsum('boij,bopfj->bopfi' ,F_ops, pnts) framed_input = framed_input.mean(1 ) if center is not None : framed_input = framed_input + center.transpose(1 ,2 ).unsqueeze(-2 ) framed_input = framed_input.contiguous() framed_input = framed_input.view(batch_size,-1 ,3 ) return framed_input