import math

import torch
import torch.nn as nn
from rdkit import Chem
from rdkit import rdBase

# Disable RDKit's 'rdApp.*' log channels so that SMILES which fail to parse
# in `validity` do not print a message each.
rdBase.DisableLog('rdApp.*')


# Multi-character tokens that `split` keeps together as a single word.
_TWO_CHAR_TOKENS = frozenset({
    # Two-letter element symbols.
    'Cl', 'Ca', 'Cu', 'Br', 'Be', 'Ba', 'Bi', 'Si', 'Se', 'Sr',
    'Na', 'Ni', 'Rb', 'Ra', 'Xe', 'Li', 'Al', 'As', 'Ag', 'Au',
    'Mg', 'Mn', 'Te', 'Zn', 'He', 'Kr', 'Fe',
    # Aromatic (lower-case) two-letter symbols.
    'si', 'se', 'te',
    # Multi-valued charges.
    '+2', '+3', '+4', '-2', '-3', '-4',
})


def split(sm: str) -> str:
    """Split a SMILES string into space-separated words.

    Two-letter symbols listed in ``_TWO_CHAR_TOKENS`` (Cl, Br, Si, Se, Na,
    ..., plus charges such as ``+2``/``-3``) are kept as one word. A ``%``
    starts a three-character word (``%`` followed by a two-digit ring-closure
    number). Every other character becomes its own word.

    Args:
        sm: A SMILES string.

    Returns:
        The words of ``sm`` joined by single spaces ('' for an empty input).
    """
    words = []
    n = len(sm)
    i = 0
    # Stop one short of the end so that sm[i:i + 2] always has two characters.
    while i < n - 1:
        if sm[i] == '%':
            words.append(sm[i:i + 3])
            i += 3
        elif sm[i:i + 2] in _TWO_CHAR_TOKENS:
            words.append(sm[i:i + 2])
            i += 2
        else:
            words.append(sm[i])
            i += 1
    # A lone trailing character was not consumed by the loop above.
    if i == n - 1:
        words.append(sm[i])
    return ' '.join(words)


# Activation function
class GELU(nn.Module):
    """GELU activation using the tanh approximation.

    Computes 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3))).
    """

    def forward(self, x):
        return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))


# Position-wise feed-forward network
class PositionwiseFeedForward(nn.Module):
    """Two-layer feed-forward block applied independently at every position.

    Computes ``w_2(dropout(gelu(w_1(x))))``, projecting the last dimension
    d_model -> d_ff -> d_model.

    Args:
        d_model: Size of the input/output feature dimension.
        d_ff: Size of the hidden feature dimension.
        dropout: Dropout probability applied after the activation.
    """

    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)
        self.activation = GELU()

    def forward(self, x):
        return self.w_2(self.dropout(self.activation(self.w_1(x))))


# Normalization layer
class LayerNorm(nn.Module):
    """Layer normalisation over the last dimension with a learned scale/shift.

    Note: this is not numerically identical to ``torch.nn.LayerNorm``. It
    divides by the *unbiased* standard deviation (``Tensor.std`` default) and
    adds ``eps`` to the std rather than to the variance. Changing either would
    alter outputs, so both are intentionally kept.

    Args:
        features: Size of the last dimension (length of the scale/shift).
        eps: Small constant added to the std for numerical stability.
    """

    def __init__(self, features, eps=1e-6):
        super().__init__()
        self.a_2 = nn.Parameter(torch.ones(features))   # scale (gain)
        self.b_2 = nn.Parameter(torch.zeros(features))  # shift (bias)
        self.eps = eps

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        return self.a_2 * (x - mean) / (std + self.eps) + self.b_2


class SublayerConnection(nn.Module):
    """Pre-norm residual wrapper: ``x + dropout(sublayer(norm(x)))``.

    Args:
        size: Feature size passed to the internal ``LayerNorm``.
        dropout: Dropout probability applied to the sublayer output.
    """

    def __init__(self, size, dropout):
        super().__init__()
        self.norm = LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        return x + self.dropout(sublayer(self.norm(x)))


# Sample SMILES tokens from a probabilistic distribution
def sample(msms):
    """Draw one token index per position from each distribution in ``msms``.

    Args:
        msms: Iterable of tensors. ``exp()`` is applied before sampling, so
            each entry is presumably log-probabilities over the vocabulary;
            assumed shape ``(seq_len, vocab)`` -- TODO confirm against caller.

    Returns:
        ``torch.stack`` of the sampled index tensors, one per entry of
        ``msms`` (shape ``(len(msms), seq_len)`` for 2-D entries).
    """
    # squeeze(-1) drops only multinomial's trailing sample dimension; a bare
    # squeeze() would also collapse seq_len when it equals 1 and change the
    # rank of the result.
    return torch.stack([torch.multinomial(msm.exp(), 1).squeeze(-1) for msm in msms])


def validity(smiles):
    """Return the fraction of SMILES strings that RDKit can parse.

    Args:
        smiles: Sized sequence of SMILES strings.

    Returns:
        A float in [0, 1]: the share of entries for which
        ``Chem.MolFromSmiles`` does not return ``None``. An empty input
        returns 0.0 instead of dividing by zero.
    """
    n = len(smiles)
    if n == 0:
        return 0.0
    n_invalid = sum(Chem.MolFromSmiles(sm) is None for sm in smiles)
    return 1 - n_invalid / n