-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathDynamicReductionNetwork.py
More file actions
69 lines (63 loc) · 2.71 KB
/
Copy pathDynamicReductionNetwork.py
File metadata and controls
69 lines (63 loc) · 2.71 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import torch
import torch.nn as nn
from .DynamicReductionNetworkJit import DynamicReductionNetworkJit
class DynamicReductionNetwork(nn.Module):
    """
    Wrapper around DynamicReductionNetworkJit that optionally TorchScript-compiles it.

    This model iteratively contracts nearest-neighbour graphs
    until there is one output node.
    The latent space is trained to group useful features at each level
    of aggregation.
    This allows single quantities to be regressed from complex point clouds
    in a location- and orientation-invariant way.
    One encoding layer is used to abstract away the input features.

    @param input_dim: dimension of input features
    @param hidden_dim: dimension of hidden layers
    @param output_dim: dimension of output
    @param k: size of k-nearest neighbor graphs
    @param aggr: message passing aggregation scheme.
    @param norm: feature normalization. None is equivalent to all 1s (ie no scaling)
    @param loop: boolean for presence/absence of self loops in k-nearest neighbor graphs
        (NOTE: currently not forwarded to the inner network -- see __init__)
    @param pool: type of pooling in aggregation layers. Choices are 'add', 'max', 'mean'
        (NOTE: currently not forwarded to the inner network -- see __init__)
    @param agg_layers: number of aggregation layers. Must be >=0
    @param mp_layers: number of layers in message passing networks. Must be >=1
    @param in_layers: number of layers in inputnet. Must be >=1
    @param out_layers: number of layers in outputnet. Must be >=1
    @param graph_features: forwarded unchanged to DynamicReductionNetworkJit
        (presumably enables per-graph features read from data.graph_x -- confirm)
    @param latent_probe: forwarded unchanged to DynamicReductionNetworkJit
    @param actually_jit: if True, the inner network is compiled with
        torch.jit.script; otherwise the plain nn.Module is used
    """

    def __init__(self, input_dim=4, hidden_dim=64, output_dim=1, k=16, aggr='add', norm=None,
                 loop=True, pool='max',
                 agg_layers=2, mp_layers=2, in_layers=1, out_layers=3,
                 graph_features=False,
                 latent_probe=None,
                 actually_jit=True,
                 ):
        super().__init__()
        # NOTE(review): `loop` and `pool` are accepted (and documented) but are
        # not passed to DynamicReductionNetworkJit, so they currently have no
        # effect. Forward them once that class's signature is confirmed to
        # accept them.
        drn = DynamicReductionNetworkJit(
            input_dim=input_dim,
            hidden_dim=hidden_dim,
            output_dim=output_dim,
            k=k,
            aggr=aggr,
            norm=norm,
            agg_layers=agg_layers,
            mp_layers=mp_layers,
            in_layers=in_layers,
            out_layers=out_layers,
            graph_features=graph_features,
            latent_probe=latent_probe,
        )
        # Optionally compile the inner network with TorchScript.
        self.drn = torch.jit.script(drn) if actually_jit else drn

    def forward(self, data):
        """
        Push the batch `data` through the network.

        `data` must expose `x`. `batch` and `graph_x` are optional. If `batch`
        is missing or None, it is replaced by an all-zeros int64 vector with one
        entry per row of `x`, on the same device as `x`, so every node is
        assigned to graph 0. If `graph_x` is missing, None is passed instead.

        @param data: batch object (presumably a torch_geometric Data/Batch -- confirm)
        @return: whatever the wrapped network returns for (x, batch, graph_x)
        """
        x = data.x
        # getattr(..., None) handles both objects that lack the attribute and
        # objects that expose it as None when unset (presumably the case for
        # torch_geometric Data -- confirm).
        batch = getattr(data, 'batch', None)
        if batch is None:
            batch = torch.zeros((x.shape[0],), dtype=torch.int64, device=x.device)
        graph_x = getattr(data, 'graph_x', None)
        return self.drn(x, batch, graph_x)