import glob
import os
import pickle
import re
import sys
from functools import partial
import warnings
import h5py
import numpy as np
from tqdm import tqdm
import pdb2sql
from deeprank import config
from deeprank.config import logger
from deeprank.generate import MinMaxParam, NormalizeData, NormParam
from deeprank.tools import sparse
# import torch.utils.data as data_utils
# The class used to subclass data_utils.Dataset,
# but that conflicted with Sphinx, which couldn't build the API.
# Subclassing is apparently not necessary: the class works without it.
[docs]class DataSet():
def __init__(self, train_database, valid_database=None, test_database=None,
chain1='A', chain2='B',
mapfly=True, grid_info=None,
use_rotation=None,
select_feature='all', select_target='DOCKQ',
normalize_features=True, normalize_targets=True,
target_ordering=None,
dict_filter=None, pair_chain_feature=None,
transform_to_2D=False, projection=0,
clip_features=True, clip_factor=1.5,
rotation_seed=None,
tqdm=False,
process=True):
'''Generates the dataset needed for pytorch.
This class hanldes the data generated by deeprank.generate to be
used in the deep learning part of DeepRank.
Args:
train_database (list(str)): names of the hdf5 files used for
the training/validation.
Example: ['1AK4.hdf5','1B7W.hdf5',...]
valid_database (list(str)): names of the hdf5 files used for
the validation.
Example: ['1ACB.hdf5','4JHF.hdf5',...]
test_database (list(str)): names of the hdf5 files used for
the test.
Example: ['7CEI.hdf5']
chain1 (str): first chain ID, defaults to 'A'
chain2 (str): second chain ID, defaults to 'B'
mapfly (bool): do we compute the map in the batch
preparation or read them
grid_info(dict): grid information to map the feature.
If None the original grid points are used.
The dict contains:
- 'number_of_points", the shape of grid
- 'resolution', the resolution of grid, unit in A
Example:
{'number_of_points': [10, 10, 10], 'resolution': [3, 3, 3]}
use_rotation (int): number of rotations to use.
Example: 0 (use only original data)
Default: None (use all data of the database)
select_feature (dict or 'all', optional):
Select the features used in the learning.
if mapfly is True:
- {'AtomDensities': 'all', 'Features': 'all'}
- {'AtomicDensities': config.atom_vdw_radius_noH,
'Features': ['PSSM_*', 'pssm_ic_*']}
if mapfly is False:
- {'AtomDensities_ind': 'all', 'Feature_ind': 'all'}
- {'Feature_ind': ['PSSM_*', 'pssm_ic_*']}
Default: 'all'
select_target (str,optional): Specify required target.
Default: 'DOCKQ'
normalize_features (Bool, optional): normalize features or not
Default: True
normalize_targets (Bool, optional): normalize targets or not
Default: True
target_ordering (str): 'lower' (the lower the better) or
'higher' (the higher the better)
By default is not specified (None) and the code tries
to identify it. If identification fails 'lower' is used.
dict_filter (None or dict, optional): Specify if we filter
the complexes based on target values,
Example: {'IRMSD': '<4. or >10'}
(select complexes with IRMSD lower than 4 or larger than 10)
Default: None
pair_chain_feature (None or callable, optional):
method to pair features of chainA and chainB
Example: np.sum (sum the chainA and chainB features)
transform_to_2D (bool, optional):
Boolean to use 2d maps instead of full 3d
Default: False
projection (int): Projection axis from 3D to 2D:
Mapping: 0 -> yz, 1 -> xz, 2 -> xy
Default = 0
clip_features (bool, optional):
Remove too large values of the grid.
Can be needed for native complexes where the coulomb
feature might be too large
clip_factor (float, optional): the features are clipped at:
+/-mean + clip_factor * std
tqdm (bool, optional): Print the progress bar
process (bool, optional): Actually process the data set.
Must be set to False when reusing a model for testing
rotation_seed(int, optional): random seed for getting rotation
axis and angle.
Examples:
>>> from deeprank.learn import *
>>> train_database = '1ak4.hdf5'
>>> data_set = DataSet(train_database,
>>> valid_database = None,
>>> test_database = None,
>>> chain1='C',
>>> chain2='D',
>>> grid_info = {
>>> 'number_of_points': (10, 10, 10),
>>> 'resolution': (3, 3, 3)
>>> },
>>> select_feature = {
>>> 'AtomicDensities': 'all',
>>> 'Features': [
>>> 'PSSM_*', 'pssm_ic_*' ]
>>> },
>>> select_target='IRMSD',
>>> normalize_features = True,
>>> normalize_targets=True,
>>> pair_chain_feature=np.add,
>>> dict_filter={'IRMSD':'<4. or >10.'},
>>> process = True)
'''
# allow for multiple database
self.train_database = self._get_database_name(train_database)
# allow for multiple database
self.valid_database = self._get_database_name(valid_database)
# allow for multiple database
self.test_database = self._get_database_name(test_database)
# chainIDs
self.chain1 = chain1
self.chain2 = chain2
# pdb selection
self.use_rotation = use_rotation
# features/targets selection
self.select_feature = select_feature
self.select_target = select_target
# map generation
self.mapfly = mapfly
self.grid_info = grid_info
self._grid_shape = None
# data agumentation
if self.mapfly:
self.data_augmentation = use_rotation
if self.data_augmentation is None:
self.data_augmentation = 0
else:
self.data_augmentation = 0
# normalization conditions
self.normalize_features = normalize_features
self.normalize_targets = normalize_targets
# clip the data
self.clip_features = clip_features
self.clip_factor = clip_factor
# shape of the data
self.input_shape = None
self.data_shape = None
# the possible pairing of the ind features
self.pair_chain_feature = pair_chain_feature
# get the eventual projection
self.transform = transform_to_2D
self.proj2D = projection
# filter the dataset
self.dict_filter = dict_filter
# target ordered lower the better or higher the better
self._get_target_ordering(target_ordering)
# print the progress bar or not
self.tqdm = tqdm
# set random seed
self.rotation_seed = rotation_seed
# process the data
if process:
self.process_dataset()
[docs] @staticmethod
def _get_database_name(database):
"""Get the list of hdf5 database file names.
Args:
database(None, str or list(str)): hdf5 database name(s).
Returns:
list: hdf5 file names
"""
# make sure the return is only one data type
filenames = []
if database is not None:
if not isinstance(database, list):
database = [database]
for db in database:
filenames += glob.glob(db)
return filenames
def process_dataset(self):
    """Process the data set.

    Done by default when the ``DataSet`` is created. It must be turned
    off when one wants to test a pretrained model, by setting
    ``process=False`` in the creation of the ``DataSet`` instance.

    Steps: log the databases, drop unreadable/empty hdf5 files, index
    the molecules, resolve the feature names, build the chain pairing,
    get the grid and input shapes, compute or read the normalization
    factors (when normalization or clipping is requested) and log a
    summary of the data set.
    """
    separator = '=' * 40
    logger.info('\n')
    logger.info(separator)
    logger.info('=\t DeepRank Data Set')
    logger.info('=')

    # the training data is always listed, validation/test only if given
    listings = [('=\t Training data', self.train_database)]
    if self.valid_database:
        listings.append(('=\t Validation data', self.valid_database))
    if self.test_database:
        listings.append(('=\t Test data', self.test_database))
    for title, files in listings:
        logger.info(title)
        for path in files:
            logger.info(f'=\t -> {path}')
        logger.info('=')
    logger.info(separator + '\n')
    sys.stdout.flush()

    # integrity check; the training list is cleaned in place
    self.check_hdf5_files(self.train_database)
    if self.valid_database:
        self.valid_database = self.check_hdf5_files(self.valid_database)
    if self.test_database:
        self.test_database = self.check_hdf5_files(self.test_database)

    # associate each molecule with an index, so that the file name and
    # molecule name can be retrieved from the index
    self.create_index_molecules()

    # resolve the actual feature names
    if self.mapfly:
        self.get_raw_feature_name()
    else:
        self.get_mapped_feature_name()

    self.get_pairing_feature()
    self.get_grid_shape()
    self.get_input_shape()

    # renormalization factors
    needs_norm = (self.normalize_features or self.normalize_targets
                  or self.clip_features)
    if needs_norm:
        if self.mapfly:
            self.compute_norm()
        else:
            self.get_norm()

    logger.info('\n')
    logger.info(" Data Set Info:")
    logger.info(f' Augmentation : {self.use_rotation} rotations')
    logger.info(f' Training set : {self.ntrain} conformations')
    logger.info(f' Validation set : {self.nvalid} conformations')
    logger.info(f' Test set : {self.ntest} conformations')
    logger.info(f' Number of channels : {self.input_shape[0]}')
    shape = self.data_shape
    logger.info(f' Grid Size : {shape[1]}, {shape[2]}, {shape[3]}')
    sys.stdout.flush()
def __len__(self):
"""Get the length of the dataset
Returns:
int: number of complexes in the dataset
"""
return len(self.index_complexes)
def __getitem__(self, index):
"""Get one item from its unique index.
Args:
index (int): index of the complex
Returns:
dict: {'mol':[fname,mol],'feature':feature,'target':target}
"""
fname, mol, angle, axis = self.index_complexes[index]
try:
if self.mapfly:
feature, target = self.map_one_molecule(
fname, mol, angle, axis)
else:
feature, target = self.load_one_molecule(fname, mol)
if self.clip_features:
feature = self._clip_feature(feature)
if self.normalize_features:
feature = self._normalize_feature(feature)
if self.normalize_targets:
target = self._normalize_target(target)
if self.pair_chain_feature:
feature = self.make_feature_pair(
feature, self.pair_chain_feature)
if self.transform:
feature = self.convert2d(feature, self.proj2D)
return {'mol': [fname, mol], 'feature': feature, 'target': target}
except:
raise
print('Unable to load molecule %s from %s' % (mol, fname))
@staticmethod
def check_hdf5_files(database):
    """Check that the hdf5 files of ``database`` are readable and not empty.

    Unreadable or empty files are removed from ``database`` *in place*,
    and the same list object is returned. The in-place removal matters:
    ``process_dataset`` does not reassign the training list.

    Args:
        database (list(str)): hdf5 file names

    Returns:
        list(str): ``database`` with the bad files removed
    """
    logger.info(" Checking dataset Integrity")
    remove_file = []
    for fname in database:
        try:
            # the context manager closes the file even if keys() fails
            with h5py.File(fname, 'r') as f:
                is_empty = len(f.keys()) == 0
        except Exception:
            # Exception (not BaseException) so Ctrl-C is not swallowed
            warnings.warn(' -> %s is corrputed ' % fname)
            remove_file.append(fname)
            continue
        if is_empty:
            warnings.warn(' -> %s is empty ' % fname)
            remove_file.append(fname)

    for name in remove_file:
        database.remove(name)
    if remove_file:
        logger.info(f'\t -> Empty or corrput databases are removed:\n'
                    f'{remove_file}')

    return database
def create_index_molecules(self):
    """Create the indexing of each molecule in the dataset.

    Create the indexing:
    [('1ak4.hdf5', '1AK4_100w', angle, axis), ...,
     ('1fqj.hdf5', '1FGJ_400w', angle, axis)]
    ``angle``/``axis`` are None for unrotated complexes and describe the
    on-the-fly rotation otherwise. This allows referring to one complex
    by its index in the list.

    Training entries come first, then validation entries, then test
    entries. ``index_train``, ``index_valid``, ``index_test`` and the
    counts ``ntrain``, ``nvalid``, ``ntest`` are set accordingly (the
    validation/test values are empty or 0 when no such database is given).

    Raises:
        ValueError: No available training data after filtering.
    """
    logger.info("\n\n Processing data set:")

    self.index_complexes = []

    # Training dataset: filtered and augmented
    self._add_database_to_index(
        self.train_database, ' Train dataset', is_train=True)
    self.ntrain = len(self.index_complexes)
    self.index_train = list(range(self.ntrain))
    if self.ntrain == 0:
        raise ValueError(
            'No avaiable training data after filtering')

    # Validation dataset
    if self.valid_database:
        self._add_database_to_index(
            self.valid_database, ' Validation dataset')
    self.ntot = len(self.index_complexes)
    self.index_valid = list(range(self.ntrain, self.ntot))
    self.nvalid = self.ntot - self.ntrain

    # Test dataset
    if self.test_database:
        self._add_database_to_index(self.test_database, ' Test dataset')
    self.ntot = len(self.index_complexes)
    self.index_test = list(
        range(self.ntrain + self.nvalid, self.ntot))
    self.ntest = self.ntot - self.ntrain - self.nvalid

def _add_database_to_index(self, database, label, is_train=False):
    """Append the (fname, mol, angle, axis) entries of one database.

    Training files additionally go through ``self.filter`` (when a
    ``dict_filter`` is set). Each kept complex gets
    ``self.data_augmentation`` rotated copies whose axis/angle come from
    ``pdb2sql.transform.get_rot_axis_angle``. Files that cannot be read
    are logged and skipped.

    Args:
        database (list(str)): hdf5 file names
        label (str): name used for the progress bar / log message
        is_train (bool): apply filtering and data augmentation
    """
    desc = '{:25s}'.format(label)
    if self.tqdm:
        data_iter = tqdm(database, desc=desc, file=sys.stdout)
    else:
        logger.info(label)
        data_iter = database
    sys.stdout.flush()

    for fdata in data_iter:
        if self.tqdm:
            data_iter.set_postfix(mol=os.path.basename(fdata))
        try:
            with h5py.File(fdata, 'r') as fh5:
                mol_names = self._select_pdb(list(fh5.keys()))

                if not is_train:
                    self.index_complexes += [(fdata, k, None, None)
                                             for k in mol_names]
                    continue

                # Entries are appended (never reassigned) so that every
                # training file is kept, and the rotations are added
                # whether or not a filter is used.
                for k in mol_names:
                    # skip the filter lookup entirely when there is none
                    if self.dict_filter and not self.filter(fh5[k]):
                        continue
                    self.index_complexes.append((fdata, k, None, None))
                    for _ in range(self.data_augmentation):
                        # NOTE(review): if get_rot_axis_angle reseeds on
                        # every call, a fixed rotation_seed gives identical
                        # rotations -- confirm in pdb2sql.
                        axis, angle = pdb2sql.transform.get_rot_axis_angle(
                            self.rotation_seed)
                        self.index_complexes.append((fdata, k, angle, axis))
        except Exception:
            logger.exception(f'Ignore file: {fdata}')
[docs] def _select_pdb(self, mol_names):
"""Select complexes.
Args:
mol_names (list): list of complex names
Returns:
list: list of selected complexes
"""
fnames_original = list(
filter(lambda x: not re.search(r'_r\d+$', x), mol_names))
if self.use_rotation is not None:
fnames_augmented = []
# TODO if there is no augmentation data in dataaset,
# the fnames_augmented should be 0, should report it.
if self.use_rotation > 0:
for i in range(self.use_rotation):
fnames_augmented += list(filter(lambda x:
re.search('_r%03d$' % (i + 1), x), mol_names))
selected_mol_names = fnames_original + fnames_augmented
else:
selected_mol_names = fnames_original
else:
selected_mol_names = mol_names
sample_id = fnames_original[0]
num_rotations = len(list((filter(lambda x:
re.search(sample_id + '_r', x), mol_names))))
self.use_rotation = num_rotations
return selected_mol_names
[docs] def filter(self, molgrp):
"""Filter the molecule according to a dictionary, e.g.,
dict_filter={'DOCKQ':'>0.1', 'IRMSD':'<=4 or >10'}).
The filter is based on the attribute self.dict_filter
that must be either of the form: { 'name': cond } or None
Args:
molgrp (str): group name of the molecule in the hdf5 file
Returns:
bool: True if we keep the complex False otherwise
Raises:
ValueError: If an unsuported condition is provided
"""
if self.dict_filter is None:
return True
for cond_name, cond_vals in self.dict_filter.items():
try:
val = molgrp['targets/' + cond_name][()]
except KeyError:
warnings.warn(f'Filter {cond_name} not found for mol '
f'{molgrp.name}')
# if we have a string it's more complicated
if isinstance(cond_vals, str):
ops = ['>', '<', '==', '<=', '>=']
new_cond_vals = cond_vals
for o in ops:
new_cond_vals = new_cond_vals.replace(
o, 'val' + o)
if not eval(new_cond_vals):
return False
else:
raise ValueError(
'Conditions not supported', cond_vals)
return True
def get_mapped_feature_name(self):
    """Get actual mapped feature names for feature selections.

    ``self.select_feature`` is replaced by a dict mapping each feature
    type to the explicit list of mapped feature names. The names are
    looked up in the first molecule of the first training file.

    Note:
        - class parameter self.select_feature examples:
            - 'all'
            - {'AtomicDensities_ind': 'all', 'Feature_ind': 'all'}
            - {'Feature_ind': ['PSSM_*', 'pssm_ic_*']}
        - Feature type must be one of the groups under
          ``mapped_features`` (e.g. 'AtomicDensities_ind' or 'Feature_ind').
        - A name without a chain tag is expanded to ``<name>_chain1`` and
          ``<name>_chain2``.
        - A name with a wild card (e.g. ``PSSM_*``) is expanded to every
          name starting with the part before the ``*``.
        - A name already carrying a chain tag (e.g. when reloading a
          pretrained model) is kept as is.

    Raises:
        KeyError: Wrong feature type.
    """
    chain_tags = ['_chain1', '_chain2']

    # the context manager also closes the file on the KeyError paths
    with h5py.File(self.train_database[0], 'r') as f5:
        mol_name = list(f5.keys())[0]
        mapped_data = f5.get(mol_name + '/mapped_features/')

        # select every feature of every type
        if self.select_feature == "all":
            self.select_feature = {
                feat_type: [name for name in feat_names]
                for feat_type, feat_names in mapped_data.items()}
            return

        # only existing keys are reassigned, so iterating is safe
        for feat_type, feat_names in self.select_feature.items():

            if feat_type not in mapped_data:
                self.print_possible_features()
                raise KeyError('Feature type %s not found' % feat_type)

            if feat_names == 'all':
                self.select_feature[feat_type] = list(
                    mapped_data[feat_type].keys())
                continue

            # TODO to refactor this part
            available = list(mapped_data[feat_type].keys())
            selected = []
            for name in feat_names:
                if any(tag in name for tag in chain_tags):
                    # chain tag already present (we probably reload a
                    # pretrained model): keep the name untouched
                    selected.append(name)
                elif '*' in name:
                    # wild card, e.g. PSSM_*: add every matching name
                    prefix = name.split('*')[0]
                    selected += [n for n in available if n.startswith(prefix)]
                else:
                    # plain name: expand it to both chains
                    selected += [name + tag for tag in chain_tags]
            self.select_feature[feat_type] = selected
def get_raw_feature_name(self):
    """Get actual raw feature names for feature selections.

    The names are looked up in the first molecule of the first training
    file. ``self.select_feature`` is updated in place, or replaced when
    it is 'all'.

    Note:
        - class parameter self.select_feature examples:
            - 'all'
            - {'AtomicDensities': 'all', 'Features': 'all'}
            - {'AtomicDensities': config.atom_vdw_radius_noH,
               'Features': ['PSSM_*', 'pssm_ic_*']}
        - Feature type must be: 'AtomicDensities' or 'Features'.
        - 'AtomicDensities' is either 'all' (config.atom_vdw_radius_noH)
          or a dict of element type and vdw radius.
        - Names with a wild card (e.g. ``PSSM_*``) are expanded to every
          raw feature starting with the part before the ``*``.

    Raises:
        KeyError: Wrong feature type.
        TypeError: 'AtomicDensities' is neither 'all' nor a dict.
    """
    # the context manager also closes the file on the error paths
    with h5py.File(self.train_database[0], 'r') as f5:
        mol_name = list(f5.keys())[0]
        raw_data = f5.get(mol_name + '/features/')

        # select every feature
        if self.select_feature == "all":
            self.select_feature = {
                'AtomicDensities': config.atom_vdw_radius_noH,
                'Features': [name for name in raw_data.keys()]}
            return

        # only existing keys are reassigned, so iterating is safe
        for feat_type, feat_names in self.select_feature.items():

            if feat_type == 'AtomicDensities':
                if feat_names == 'all':
                    self.select_feature['AtomicDensities'] = \
                        config.atom_vdw_radius_noH
                elif not isinstance(feat_names, dict):
                    # explicit check instead of an assert (stripped by -O)
                    raise TypeError(
                        'AtomicDensities must be "all" or a dict of '
                        'element type and vdw radius, got %r' % (feat_names,))

            elif feat_type == 'Features':
                possible_names = list(raw_data.keys())
                if feat_names == 'all':
                    self.select_feature[feat_type] = possible_names
                    continue
                selected = []
                for name in feat_names:
                    if '*' in name:
                        prefix = name.split('*')[0]
                        selected += [n for n in possible_names
                                     if n.startswith(prefix)]
                    else:
                        selected.append(name)
                self.select_feature[feat_type] = selected

            else:
                raise KeyError(
                    f'Wrong feature type {feat_type}. '
                    f'It should be "AtomicDensities" or "Features".')
def print_possible_features(self):
    """Log the mapped features available in the first training molecule.

    When ``self.select_feature`` is a dict, the current selection is
    logged too. Feature types absent from the file are highlighted.
    """
    # the file used to be left open; the context manager closes it
    with h5py.File(self.train_database[0], 'r') as f5:
        mol_name = list(f5.keys())[0]
        mapgrp = f5.get(mol_name + '/mapped_features/')
        available_types = list(mapgrp.keys())

        logger.info('\nPossible Features:')
        logger.info('-' * 20)
        for feat_type in available_types:
            logger.info('== %s' % feat_type)
            for fname in list(mapgrp[feat_type].keys()):
                logger.info(' -- %s' % fname)

    # select_feature may still be the string 'all', which has no .items()
    if isinstance(self.select_feature, dict):
        logger.info('\nYour selection was:')
        for feat_type, feat in self.select_feature.items():
            if feat_type not in available_types:
                logger.info(
                    '== \x1b[0;37;41m' + feat_type + '\x1b[0m')
            else:
                logger.info('== %s' % feat_type)
            if isinstance(feat, str):
                logger.info(' -- %s' % feat)
            if isinstance(feat, list):
                for name in feat:
                    logger.info(' -- %s' % name)

        # get_mapped_feature_name appends '_chain1' / '_chain2'
        logger.info("You don't need to specify _chain1 _chain2 for each feature. " +
                    "The code will append it automatically")
[docs] def get_pairing_feature(self):
"""Creates the index of paired features."""
if self.pair_chain_feature:
self.pair_indexes = []
start = 0
for feat_type, feat_names in self.select_feature.items():
nfeat = len(feat_names)
if '_ind' in feat_type:
self.pair_indexes += [
[i, i + 1] for i in range(start, start + nfeat, 2)]
else:
self.pair_indexes += [
[i] for i in range(start, start + nfeat)]
start += nfeat
[docs] def get_grid_shape(self):
"""Get the shape of the matrices.
Raises:
ValueError: If no grid shape is provided or is present in
the HDF5 file
"""
if self.mapfly is False:
fname = self.train_database[0]
fh5 = h5py.File(fname, 'r')
mol = list(fh5.keys())[0]
# get the mol
mol_data = fh5.get(mol)
# get the grid size
if 'grid_points' in mol_data:
nx = mol_data['grid_points']['x'].shape[0]
ny = mol_data['grid_points']['y'].shape[0]
nz = mol_data['grid_points']['z'].shape[0]
self._grid_shape = (nx, ny, nz)
fh5.close()
if self._grid_shape is None:
if self.grid_info is not None:
self._grid_shape = self.grid_info['number_of_points']
else:
raise ValueError(
f'Impossible to determine sparse grid shape.\n'
f'If you are not loading a pretrained model, '
f' specify argument "grid_info".')
def compute_norm(self):
    """Compute the normalization factors from the training complexes.

    Only the training entries (the first ``self.ntrain`` entries of
    ``self.index_complexes``) are used. This matches ``_read_norm``,
    which only reads the training databases. Including validation/test
    complexes would leak their statistics into the training (and test
    complexes may have no target).

    Sets ``self.param_norm``, ``self.feature_mean``, ``self.feature_std``
    (null std replaced by 1), ``self.target_min`` and ``self.target_max``.

    Raises:
        ValueError: if there is no training complex to compute from.
    """
    train_complexes = self.index_complexes[:self.ntrain]
    if not train_complexes:
        raise ValueError(
            'No training complex available to compute the normalization')

    self.param_norm = None
    for comp in tqdm(train_complexes):
        fname, molname = comp[0], comp[1]

        # get the feature/target
        if self.mapfly:
            feature, target = self.map_one_molecule(fname, mol=molname)
        else:
            feature, target = self.load_one_molecule(fname, mol=molname)

        # create the norm instances at the first passage
        if self.param_norm is None:
            self.param_norm = {
                'features': [NormParam() for _ in range(feature.shape[0])],
                'targets': MinMaxParam()}

        # update the norm instances
        for ifeat, mat in enumerate(feature):
            self.param_norm['features'][ifeat].add(
                np.mean(mat), np.var(mat))
        self.param_norm['targets'].update(target)

    # process the std of the features and make arrays for fast access
    ncomplex = len(train_complexes)
    self.feature_mean, self.feature_std = [], []
    for param in self.param_norm['features']:
        param.process(ncomplex)
        if param.std == 0:
            logger.info(' Final STD Null. Changed it to 1')
            param.std = 1
        self.feature_mean.append(param.mean)
        self.feature_std.append(param.std)

    self.target_min = self.param_norm['targets'].min[0]
    self.target_max = self.param_norm['targets'].max[0]

    logger.info(f'{self.target_min}, {self.target_max}')
def get_norm(self):
    """Get the normalization values for the features from the norm files.

    One ``NormParam`` per selected feature and one ``MinMaxParam`` for
    the target are created and filled by ``_read_norm``. They are then
    flattened into ``self.feature_mean`` / ``self.feature_std`` (in the
    order of ``self.select_feature``) and ``self.target_min`` /
    ``self.target_max``.
    """
    self.param_norm = {
        'features': {
            feat_type: {name: NormParam() for name in feat_names}
            for feat_type, feat_names in self.select_feature.items()
        },
        'targets': {self.select_target: MinMaxParam()},
    }

    # accumulate the values stored in the norm files
    self._read_norm()

    # flat lists for fast access
    selected_params = [
        self.param_norm['features'][feat_type][name]
        for feat_type, feat_names in self.select_feature.items()
        for name in feat_names
    ]
    self.feature_mean = [param.mean for param in selected_params]
    self.feature_std = [param.std for param in selected_params]

    target_param = self.param_norm['targets'][self.select_target]
    self.target_min = target_param.min
    self.target_max = target_param.max
def _read_norm(self):
    """Read or create the normalization file of each training database.

    For every training hdf5 file, the companion ``<name>_norm.pckl`` is
    read. When it is missing, ``NormalizeData(...).get()`` is run first
    (presumably it writes that file -- confirm in deeprank.generate).
    Feature mean/variance are added to ``self.param_norm['features']``
    and the target min/max to ``self.param_norm['targets']``. Finally
    the feature std is processed, and null std values are set to 1.
    """
    # loop through all the filename
    for f5 in self.train_database:

        # get the precalculated data
        fdata = os.path.splitext(f5)[0] + '_norm.pckl'

        # if the file doesn't exist we create it
        if not os.path.isfile(fdata):
            logger.info(f" Computing norm for {f5}")
            norm = NormalizeData(f5, shape=self._grid_shape)
            norm.get()

        # NOTE(security): pickle can execute arbitrary code while loading;
        # only read norm files produced locally by NormalizeData.
        # The context manager closes the handle the old code leaked.
        with open(fdata, 'rb') as fpckl:
            data = pickle.load(fpckl)

        # handle the features
        for feat_type, feat_names in self.select_feature.items():
            for name in feat_names:
                mean = data['features'][feat_type][name].mean
                var = data['features'][feat_type][name].var
                if var == 0:
                    logger.info(
                        ' : STD is null for %s in %s' % (name, f5))
                self.param_norm['features'][feat_type][name].add(
                    mean, var)

        # handle the target
        minv = data['targets'][self.select_target].min
        maxv = data['targets'][self.select_target].max
        target_norm = self.param_norm['targets'][self.select_target]
        target_norm.update(minv)
        target_norm.update(maxv)

    # process the std
    nfile = len(self.train_database)
    for feat_types, feat_dict in self.param_norm['features'].items():
        for feat, param in feat_dict.items():
            param.process(nfile)
            if param.std == 0:
                logger.info(
                    ' Final STD Null for %s/%s. Changed it to 1' %
                    (feat_types, feat))
                param.std = 1
[docs] def _get_target_ordering(self, order):
"""Determine if ordering of the target.
This can be lower the better or higher the better If it can't
determine the ordering 'lower' is assumed
"""
lower_list = ['IRMSD', 'LRMSD', 'HADDOCK']
higher_list = ['DOCKQ', 'Fnat']
NA_list = ['binary_class', 'BIN_CLASS', 'class']
if order is not None:
self.target_ordering = order
else:
if self.select_target in lower_list:
self.target_ordering = 'lower'
elif self.select_target in higher_list:
self.target_ordering = 'higher'
elif self.select_target in NA_list:
self.target_ordering = None
else:
warnings.warn(
' Target ordering unidentified. lower assumed')
self.target_ordering = 'lower'
[docs] def _normalize_target(self, target):
"""Normalize the values of the targets.
Args:
target (list(float)): raw data
Returns:
list(float): normalized data
"""
# TODO why define such normlised target?
target -= self.target_min
target /= self.target_max
return target
[docs] def _normalize_feature(self, feature):
"""Normalize the values of the features.
Args:
feature (np.array): raw feature values
Returns:
np.array: normalized feature values
"""
for ic in range(self.data_shape[0]):
feature[ic] = (feature[ic] - self.feature_mean[ic]
) / self.feature_std[ic]
return feature
[docs] def _clip_feature(self, feature):
"""Clip the value of the features at +/- mean + clip_factor * std.
Args:
feature (np.array): raw feature values
Returns:
np.array: clipped feature values
"""
w = self.clip_factor
for ic in range(self.data_shape[0]):
if len(feature[ic]) > 0:
minv = self.feature_mean[ic] - w * self.feature_std[ic]
maxv = self.feature_mean[ic] + w * self.feature_std[ic]
if minv != maxv:
feature[ic] = np.clip(feature[ic], minv, maxv)
#feature[ic] = self._mad_based_outliers(feature[ic],minv,maxv)
return feature
[docs] @staticmethod
def _mad_based_outliers(points, minv, maxv, thresh=3.5):
"""Mean absolute deviation based outlier detection.
(Experimental).
Args:
points (np.array): raw input data
minv (float): Minimum (negative) value requested
maxv (float): Maximum (positive) value requested
thresh (float, optional): Threshold for data detection
Returns:
TYPE: data where outliers were replaced by min/max values
"""
median = np.median(points)
diff = np.sqrt((points - median)**2)
med_abs_deviation = np.median(diff)
if med_abs_deviation == 0:
return points
modified_z_score = 0.6745 * diff / med_abs_deviation
mask_outliers = modified_z_score > thresh
mask_max = np.abs(points - maxv) < np.abs(points - minv)
mask_min = np.abs(points - maxv) > np.abs(points - minv)
points[mask_max * mask_outliers] = maxv
points[mask_min * mask_outliers] = minv
return points
def load_one_molecule(self, fname, mol=None):
    """Load the pre-mapped features and the target of a single molecule.

    Args:
        fname (str): hdf5 file name
        mol (None or str, optional): name of the complex in the hdf5;
            the first complex of the file is used when None

    Returns:
        np.array, np.array: float32 features (one grid per selected
        feature name) and a 1-element float32 target array

    Raises:
        ValueError: the molecule has no ``mapped_features`` group, or a
            selected feature type is missing
        KeyError: a selected feature name is missing
    """
    outtype = 'float32'

    # the context manager closes the file on every error path too
    with h5py.File(fname, 'r') as fh5:
        if mol is None:
            mol = list(fh5.keys())[0]

        # get the mol
        mol_data = fh5.get(mol)

        if 'mapped_features' not in mol_data.keys():
            logger.error(f"xue: Error: mol: {mol} in {fname} does not have mapped_features ")
            # raise instead of sys.exit(): let the caller decide
            raise ValueError(
                f'{mol} in {fname} does not have mapped_features')

        # get the features
        feature = []
        for feat_type, feat_names in self.select_feature.items():

            # see if the feature exists
            feat_path = 'mapped_features/' + feat_type
            if feat_path not in mol_data.keys():
                logger.error(
                    f'Feature type {feat_type} not found in file {fname} '
                    f'for molecule {mol}.\n'
                    f'Possible feature types are:\n\t' +
                    '\n\t'.join(
                        list(mol_data['mapped_features'].keys()))
                )
                raise ValueError(feat_type, ' not supported')
            feat_dict = mol_data.get(feat_path)

            # loop through all the desired feat names
            for name in feat_names:

                # extract the group
                try:
                    data = feat_dict[name]
                except KeyError:
                    logger.error(
                        f'Feature {name} not found in file {fname} for mol '
                        f'{mol} and feature type {feat_type}.\n'
                        f'Possible feature are:\n\t' +
                        '\n\t'.join(list(
                            mol_data['mapped_features/' +
                                     feat_type].keys()
                        ))
                    )
                    # previously execution continued with an unbound or
                    # stale ``data`` (silently loading the wrong feature)
                    raise

                # sparse data is expanded to a dense grid, dense data
                # is read directly
                if data.attrs['sparse']:
                    mat = sparse.FLANgrid(sparse=True,
                                          index=data['index'][:],
                                          value=data['value'][:],
                                          shape=self._grid_shape).to_dense()
                else:
                    mat = data['value'][:]

                feature.append(mat)

        # get the target value
        try:
            target = mol_data.get('targets/' + self.select_target)[()]
        except Exception:
            target = None
            logger.warning(f'No target value for: {fname} - not required for the test set')

    # make sure all the feature have exact same type
    # if they don't collate_fn in the creation of the minibatch will fail.
    # Note returning torch.FloatTensor makes each epoch twice longer ...
    return (np.array(feature).astype(outtype),
            np.array([target]).astype(outtype))
def map_one_molecule(self, fname, mol=None, angle=None, axis=None):
    """Map the features and load the target of a single molecule.

    Args:
        fname (str): hdf5 file name
        mol (None or str, optional): name of the complex in the hdf5;
            the first complex of the file is used when None
        angle (float, optional): rotation angle passed to the mapping
            functions (None for the unrotated complex)
        axis (list, optional): rotation axis passed to the mapping
            functions (None for the unrotated complex)

    Returns:
        np.array, np.array: float32 features (atomic densities first if
        they come first in ``self.select_feature``) and a 1-element
        float32 target array (presumably NaN when the target is
        missing -- TODO confirm)
    """
    outtype = 'float32'

    # the context manager closes the file even if the mapping fails
    with h5py.File(fname, 'r') as fh5:
        if mol is None:
            mol = list(fh5.keys())[0]

        # get the mol
        mol_data = fh5.get(mol)

        grid, npts = self.get_grid(mol_data)

        # get the features; other feature types are ignored here
        feature = []
        for feat_type, feat_names in self.select_feature.items():
            if feat_type == 'AtomicDensities':
                feature += self.map_atomic_densities(
                    feat_names, mol_data, grid, npts, angle, axis)
            elif feat_type == 'Features':
                feature += self.map_feature(
                    feat_names, mol_data, grid, npts, angle, axis)

        # get the target value
        try:
            target = mol_data.get('targets/' + self.select_target)[()]
        except Exception:
            target = None
            logger.warning(f'No target value for: {fname} - not required for the test set')

    # make sure all the feature have exact same type
    # if they don't collate_fn in the creation of the minibatch will fail.
    # Note returning torch.FloatTensor makes each epoch twice longer ...
    return (np.array(feature).astype(outtype),
            np.array([target]).astype(outtype))
[docs] @staticmethod
def convert2d(feature, proj2d):
"""Convert the 3D volumetric feature to a 2D planar data set.
proj2d specifies the dimension that we want to consider as channel
for example for proj2d = 0 the 2D images are in the yz plane and
the stack along the x dimension is considered as extra channels
Args:
feature (np.array): raw features
proj2d (int): projection
Returns:
np.array: projected features
"""
nc, nx, ny, nz = feature.shape
if proj2d == 0:
feature = feature.reshape(-1, 1, ny, nz).squeeze()
elif proj2d == 1:
feature = feature.reshape(-1, nx, 1, nz).squeeze()
elif proj2d == 2:
feature = feature.reshape(-1, nx, ny, 1).squeeze()
return feature
@staticmethod
def make_feature_pair(feature, op):
    """Pair the features of both chains.

    Channels are taken two by two, (0, 1), (2, 3), ..., and each pair is
    merged into one channel with ``op(feature[2k], feature[2k + 1])``.

    Args:
        feature (np.array): raw features, channels along axis 0
        op (callable): function to combine the two features of a pair

    Returns:
        np.array: combined features, with the dtype of ``feature``

    Raises:
        ValueError: if op is not callable or the number of channels is odd
    """
    if not callable(op):
        raise ValueError('Operation not callable', op)

    n_feat = len(feature)
    if n_feat % 2:
        raise ValueError(
            f'Expected an even number of features to pair, got {n_feat}')

    # even channels are the first member of each pair, odd ones the second
    paired = [op(first, second)
              for first, second in zip(feature[0::2], feature[1::2])]
    return np.array(paired).astype(feature.dtype)
def get_grid(self, mol_data):
    """Get the meshed grid and the number of points along each axis.

    When ``self.grid_info`` is None the grid points stored in the file are
    used; otherwise a new grid is built around ``grid_points/center`` from
    ``grid_info['number_of_points']`` and ``grid_info['resolution']``.

    Args:
        mol_data(h5 group): HDF5 molecule group

    Raises:
        ValueError: Grid points not found in mol_data.

    Returns:
        tuple, tuple: meshgrid (x, y, z), each of shape (nx, ny, nz),
            and npts = (nx, ny, nz)
    """
    if self.grid_info is None:
        try:
            x = mol_data['grid_points/x'][()]
            y = mol_data['grid_points/y'][()]
            z = mol_data['grid_points/z'][()]
        except KeyError as err:
            raise ValueError(
                "Grid points not found in the data file") from err
    else:
        center = mol_data['grid_points/center'][()]
        npts = np.array(self.grid_info['number_of_points'])
        res = np.array(self.grid_info['resolution'])
        halfdim = 0.5 * (npts * res)
        low_lim = center - halfdim
        hgh_lim = low_lim + res * (npts - 1)
        x = np.linspace(low_lim[0], hgh_lim[0], npts[0])
        y = np.linspace(low_lim[1], hgh_lim[1], npts[1])
        z = np.linspace(low_lim[2], hgh_lim[2], npts[2])

    # 'ij' indexing puts x along axis 0, y along axis 1 and z along axis 2.
    # This is the same layout as the swapped call
    # `y, x, z = np.meshgrid(y, x, z)` (default 'xy' indexing); see also
    # GridTools define_grid_points().
    grid = tuple(np.meshgrid(x, y, z, indexing='ij'))

    # Take the point counts from the 3D arrays' shape: len() of a 3D
    # array only returns its first dimension (nx for every axis).
    npts = grid[0].shape
    return grid, npts
def map_atomic_densities(
        self, feat_names, mol_data, grid, npts, angle, axis):
    """Map atomic densities.

    For every element type, the contact atoms of ``self.chain1`` and
    ``self.chain2`` are optionally rotated and then summed onto the grid
    with ``self._densgrid``.

    Args:
        feat_names(dict): element type -> vdw radius
        mol_data(h5 group): HDF5 molecule group
        grid(tuple): mesh grid of x,y,z
        npts(tuple): number of points on axis x,y,z
        angle(float): rotation angle, None for no rotation
        axis(list): rotation axis

    Returns:
        list: atomic densities of each atom type on each chain, ordered
            [type1/chain1, type1/chain2, type2/chain1, ...]
    """
    sql = pdb2sql.interface(mol_data['complex'][()])
    try:
        index = sql.get_contact_atoms(chain1=self.chain1, chain2=self.chain2)

        if angle is not None:
            center = [np.mean(g) for g in grid]

        densities = []
        for elementtype, vdw_rad in feat_names.items():
            for chain in (self.chain1, self.chain2):
                # get pos of the contact atoms of correct type
                xyz = np.array(sql.get(
                    'x,y,z', rowID=index[chain], element=elementtype))

                # rotate if necessary; test emptiness with .size, because
                # comparing against np.array([]) is an elementwise
                # comparison, not an emptiness check
                if angle is not None and xyz.size > 0:
                    xyz = pdb2sql.transform.rot_xyz_around_axis(
                        xyz, axis, angle, center)

                # accumulate the density of every atom on the grid
                atdens = np.zeros(npts)
                for pos in xyz:
                    atdens += self._densgrid(pos, vdw_rad, grid, npts)
                densities.append(atdens)
    finally:
        # release the pdb2sql database even if the mapping fails
        sql._close()

    return densities
[docs] @staticmethod
def _densgrid(center, vdw_radius, grid, npts):
"""Function to map individual atomic density on the grid.
The formula is equation (1) of the Koes paper
Protein-Ligand Scoring with Convolutional NN Arxiv:1612.02751v1
Args:
center (list(float)): position of the atoms
vdw_radius (float): vdw radius of the atom
Returns:
TYPE: np.array (mapped density)
"""
x0, y0, z0 = center
dd = np.sqrt((grid[0] - x0)**2 +
(grid[1] - y0)**2 + (grid[2] - z0)**2)
dgrid = np.zeros(npts)
dgrid[dd < vdw_radius] = np.exp(
-2 * dd[dd < vdw_radius]**2 / vdw_radius**2)
dd_tmp = dd[(dd >= vdw_radius) & (dd < 1.5 * vdw_radius)]
dgrid[(dd >= vdw_radius) & (dd < 1.5 * vdw_radius)] = (
4. / np.e**2 / vdw_radius**2 * dd_tmp**2) - (
12. / np.e**2 / vdw_radius * dd_tmp) + 9. / np.e**2
return dgrid
def map_feature(self, feat_names, mol_data, grid, npts, angle, axis):
    """Map the selected features of both chains on the grid.

    Each ``features/<name>`` dataset is read as rows of
    (chain id, x, y, z, value), as sliced below. Every row is mapped with
    ``self._featgrid`` and summed into the grid of its chain.

    Args:
        feat_names(list(str)): names of the datasets under ``features/``
        mol_data(h5 group): HDF5 molecule group
        grid(tuple): mesh grid of x,y,z
        npts(tuple): number of points on axis x,y,z
        angle(float): rotation angle, None for no rotation
        axis(list): rotation axis

    Returns:
        list: two grids per feature name, chain 0 first then chain 1.
            Both stay zero when the feature has no entries.
    """
    if angle is not None:
        center = [np.mean(g) for g in grid]

    feat = []
    for name in feat_names:
        # one accumulator per chain (chain ids are 0 and 1)
        per_chain = [np.zeros(npts), np.zeros(npts)]
        data = np.array(mol_data['features/' + name][()])

        if data.shape[0] == 0:
            logger.warning(f'No {name} retrieved at the protein/protein interface')
        else:
            chain = data[:, 0]
            pos = data[:, 1:4]
            feat_value = data[:, 4]

            if angle is not None:
                pos = pdb2sql.transform.rot_xyz_around_axis(
                    pos, axis, angle, center)

            for chain_id, xyz, val in zip(chain, pos, feat_value):
                per_chain[int(chain_id)] += self._featgrid(xyz, val, grid, npts)

        feat += per_chain
    return feat
[docs] @staticmethod
def _featgrid(center, value, grid, npts):
"""Map an individual feature (atomic or residue) on the grid.
Args:
center (list(float)): position of the feature center
value (float): value of the feature
type_ (str, optional): method to map
Returns:
np.array: Mapped feature
Raises:
ValueError: Description
"""
# shortcut for th center
x0, y0, z0 = center
sigma = np.sqrt(1. / 2)
beta = 0.5 / (sigma**2)
cutoff = 5. * beta
dd = np.sqrt((grid[0] - x0)**2 +
(grid[1] - y0)**2 + (grid[2] - z0)**2)
dd[dd < cutoff] = value * np.exp(-beta * dd[dd < cutoff])
dd[dd > cutoff] = 0
#dgrid = np.zeros(npts)
#dgrid[dd<cutoff] = value*np.exp(-beta*dd[dd<cutoff])
# print(np.allclose(dgrid,dd))
return dd