import torch
import torch_geometric


class GNN(torch.nn.Module):
	"""Four stacked ``GATv2Conv`` layers followed by a linear read-out.

	The next layer's input width is computed as ``channels * heads`` from the
	constructor arguments. It is not hard-coded, so changing ``heads`` or a
	channel count cannot leave the layers with mismatched shapes.

	The defaults reproduce the original architecture:
	15 -> GATv2(16 x 8) -> GATv2(16 x 8) -> GATv2(16 x 8) -> GATv2(8 x 8) -> Linear -> 1.
	``GNN()`` is therefore equivalent to the previous hard-coded model. The
	submodule names ``conv1``..``conv4`` and ``lin1`` are unchanged, so saved
	state dicts still load.

	Args:
		in_channels: Number of input node features fed to ``conv1``.
		hidden_channels: Per-head output width of ``conv1``, ``conv2`` and ``conv3``.
		last_channels: Per-head output width of ``conv4``.
		heads: Number of attention heads in every conv layer. The per-head
			outputs are concatenated, so each layer emits ``channels * heads``
			features.
		edge_dim: Edge-feature dimensionality, passed to every conv as ``edge_dim``.
		dropout: Passed unchanged to every ``GATv2Conv`` as ``dropout``.
		out_channels: Output width of the final ``lin1`` layer.
	"""

	def __init__(
		self,
		in_channels: int = 15,
		hidden_channels: int = 16,
		last_channels: int = 8,
		heads: int = 8,
		edge_dim: int = 1,
		dropout: float = 0.25,
		out_channels: int = 1,
	):
		super().__init__()

		def make_conv(in_dim: int, out_dim: int):
			# Builds one GATv2 layer with the settings shared by all four convs.
			# add_self_loops=False: only the edges in edge_index are used.
			# NOTE(review): .jittable() prepares the layer for torch.jit.script on
			# older PyG. From PyG 2.5 it is a deprecated no-op that only warns.
			# Remove it once pre-2.5 support is no longer needed.
			return torch_geometric.nn.GATv2Conv(
				in_dim,
				out_dim,
				heads=heads,
				edge_dim=edge_dim,
				add_self_loops=False,
				dropout=dropout,
			).jittable()

		# Width of the concatenated multi-head output of conv1..conv3.
		hidden_width = hidden_channels * heads

		# Construct the layers in the original order (conv1..conv4, then lin1)
		# so that parameter initialization consumes the RNG identically.
		self.conv1 = make_conv(in_channels, hidden_channels)
		self.conv2 = make_conv(hidden_width, hidden_channels)
		self.conv3 = make_conv(hidden_width, hidden_channels)
		self.conv4 = make_conv(hidden_width, last_channels)
		self.lin1 = torch.nn.Linear(last_channels * heads, out_channels)

	def forward(
		self,
		x: torch.Tensor,
		edge_index: torch.Tensor,
		edge_attr: torch.Tensor,
	) -> torch.Tensor:
		"""Run the four GATv2 layers and the linear read-out.

		Args:
			x: Node features. Presumably shaped ``(num_nodes, in_channels)``
				(TODO: confirm against callers).
			edge_index: Graph connectivity, passed unchanged to every conv layer.
			edge_attr: Edge features, passed unchanged to every conv layer.
				Presumably its last dimension is ``edge_dim`` (TODO: confirm).

		Returns:
			The output of ``lin1`` applied to ``conv4``'s output. Its last
			dimension is ``out_channels``.
		"""
		# Layers are kept as explicit calls rather than a loop, so the method
		# stays in a form torch.jit.script can compile.
		x = torch.nn.functional.elu(self.conv1(x, edge_index, edge_attr))
		x = torch.nn.functional.elu(self.conv2(x, edge_index, edge_attr))
		x = torch.nn.functional.elu(self.conv3(x, edge_index, edge_attr))
		# NOTE(review): no non-linearity is applied between conv4 and lin1.
		# Confirm this is intentional.
		x = self.conv4(x, edge_index, edge_attr)
		return self.lin1(x)
