Source code for deeprank.generate.DataGenerator

import importlib
import copy
import os
import re
import sys
import warnings
from collections import OrderedDict

import h5py
import numpy as np

import deeprank
from deeprank import config
from deeprank.config import logger
from deeprank.generate import GridTools as gt
import pdb2sql
from pdb2sql.align import align as align_along_axis
from pdb2sql.align import align_interface

try:
    from tqdm import tqdm
except ImportError:
    def tqdm(x):
        return x

try:
    from pycuda import driver, compiler, gpuarray, tools
    import pycuda.autoinit
except ImportError:
    pass


[docs]def _printif(string, cond): return print(string) if cond else None
[docs]class DataGenerator(object): def __init__(self, chain1, chain2, pdb_select=None, pdb_source=None, pdb_native=None, pssm_source=None, align=None, compute_targets=None, compute_features=None, data_augmentation=None, hdf5='database.h5', mpi_comm=None): """Generate the data (features/targets/maps) required for deeprank. Args: chain1 (str): First chain ID chain2 (str): Second chain ID pdb_select (list(str), optional): List of individual conformation for mapping pdb_source (list(str), optional): List of folders where to find the pdbs for mapping pdb_native (list(str), optional): List of folders where to find the native comformations, nust set it if having targets to compute in parameter "compute_targets". pssm_source (list(str), optional): List of folders where to find the PSSM files align (dict, optional): Dicitionary to align the compexes, e.g. align = {"selection":{"chainID":["A","B"]},"axis":"z"}} e.g. align = {"selection":"interface","plane":"xy"} if "selection" is not specified the entire complex is used for alignement compute_targets (list(str), optional): List of python files computing the targets, "pdb_native" must be set if having targets to compute. compute_features (list(str), optional): List of python files computing the features data_augmentation (int, optional): Number of rotation performed one each complex hdf5 (str, optional): name of the hdf5 file where the data is saved, default to 'database.h5' mpi_comm (MPI_COMM): MPI COMMUNICATOR Raises: NotADirectoryError: if the source are not found Example: >>> from deeprank.generate import * >>> # sources to assemble the data base >>> pdb_source = ['./1AK4/decoys/'] >>> pdb_native = ['./1AK4/native/'] >>> pssm_source = ['./1AK4/pssm_new/'] >>> h5file = '1ak4.hdf5' >>> >>> #init the data assembler >>> database = DataGenerator(chain1='C', >>> chain2='D', >>> pdb_source=pdb_source, >>> pdb_native=pdb_native, >>> pssm_source=pssm_source, >>> data_augmentation=None, >>> compute_targets=['deeprank.targets.dockQ'], >>> compute_features=['deeprank.features.AtomicFeature', >>> 'deeprank.features.PSSM_IC', >>> 'deeprank.features.BSA'], >>> hdf5=h5file) """ self.chain1 = chain1 self.chain2 = chain2 self.pdb_select = pdb_select or [] self.pdb_source = pdb_source or [] self.pdb_native = pdb_native or [] self.pssm_source = pssm_source self.align = align if self.pssm_source is not None: config.PATH_PSSM_SOURCE = self.pssm_source self.compute_targets = compute_targets self.compute_features = compute_features self.data_augmentation = data_augmentation self.hdf5 = hdf5 self.mpi_comm = mpi_comm # set helper attributes self.all_pdb = [] self.all_native = [] self.pdb_path = [] self.feature_error = [] self.grid_error = [] self.map_error = [] self.logger = logger # handle the pdb_select if not isinstance(self.pdb_select, list): self.pdb_select = [self.pdb_select] # handle the sources if not isinstance(self.pdb_source, list): self.pdb_source = [self.pdb_source] # handle pssm source pssm_features = ('deeprank.features.FullPSSM', 'deeprank.features.PSSM_IC') if self.compute_features and \ set.intersection(set(pssm_features), set(self.compute_features)): if config.PATH_PSSM_SOURCE is None: raise ValueError( 'You must provide "pssm_source" to compute PSSM features.') # get all the conformation path for src in self.pdb_source: if os.path.isdir(src): self.all_pdb += [os.path.join(src, fname) for fname in os.listdir(src) if fname.endswith('.pdb')] elif os.path.isfile(src): self.all_pdb.append(src) # handle the native if not isinstance(self.pdb_native, list): self.pdb_native = [self.pdb_native] for src in self.pdb_native: if os.path.isdir(src): self.all_native += [os.path.join(src, fname) for fname in os.listdir(src)] if os.path.isfile(src): self.all_native.append(src) # filter the cplx if required if self.pdb_select: for i in self.pdb_select: self.pdb_path += list(filter(lambda x: i in x, self.all_pdb)) else: self.pdb_path = self.all_pdb # ==================================================================================== # # CREATE THE DATABASE ALL AT ONCE IF ALL OPTIONS ARE GIVEN # # ====================================================================================
[docs] def create_database( self, verbose=False, remove_error=True, prog_bar=False, contact_distance=8.5, random_seed=None): """Create the hdf5 file architecture and compute the features/targets. Args: verbose (bool, optional): Print creation details remove_error (bool, optional): remove the groups that errored prog_bar (bool, optional): use tqdm contact_distance (float): contact distance cutoff, defaults to 8.5Å random_seed (int): random seed for getting rotation axis and angle Raises: ValueError: If creation of the group errored. Example: >>> # sources to assemble the data base >>> pdb_source = ['./1AK4/decoys/'] >>> pdb_native = ['./1AK4/native/'] >>> pssm_source = ['./1AK4/pssm_new/'] >>> h5file = '1ak4.hdf5' >>> >>> #init the data assembler >>> database = DataGenerator(chain1='C', >>> chain2='D', >>> pdb_source=pdb_source, >>> pdb_native=pdb_native, >>> pssm_source=pssm_source, >>> data_augmentation=None, >>> compute_targets = ['deeprank.targets.dockQ'], >>> compute_features = ['deeprank.features.AtomicFeature', >>> 'deeprank.features.PSSM_IC', >>> 'deeprank.features.BSA'], >>> hdf5=h5file) >>> >>> #create new files >>> database.create_database(prog_bar=True) """ # check decoy pdb files if not self.pdb_path: raise ValueError(f"Decoy pdb files not found. Check class " f"parameters 'pdb_source' and 'pdb_select'.") # deals with the parallelization self.local_pdbs = self.pdb_path if self.mpi_comm is not None: rank = self.mpi_comm.Get_rank() size = self.mpi_comm.Get_size() else: size = 1 if size > 1: if rank == 0: pdbs = [self.pdb_path[i::size] for i in range(size)] self.local_pdbs = pdbs[0] # send to other procs for iP in range(1, size): self.mpi_comm.send(pdbs[iP], dest=iP, tag=11) else: # receive procs self.local_pdbs = self.mpi_comm.recv(source=0, tag=11) # change hdf5 name h5path, h5name = os.path.split(self.hdf5) self.hdf5 = os.path.join(h5path, f"{rank:03d}_{h5name}") # open the file self.f5 = h5py.File(self.hdf5, 'w') # set metadata to hdf5 file self.f5.attrs['DeepRank_version'] = deeprank.__version__ self.f5.attrs['pdb_source'] = [ os.path.abspath(f) for f in self.pdb_source] self.f5.attrs['pdb_native'] = [ os.path.abspath(f) for f in self.pdb_native] self.f5.attrs['pssm_source'] = os.path.abspath( self.pssm_source) if self.compute_features is not None: self.f5.attrs['features'] = self.compute_features if self.compute_targets is not None: self.f5.attrs['targets'] = self.compute_targets ################################################## # Start generating HDF5 database ################################################## self.logger.info( f'\n# Start creating HDF5 database: {self.hdf5}') # get the local progress bar desc = '{:25s}'.format('Creating database') cplx_tqdm = tqdm(self.local_pdbs, desc=desc, disable=not prog_bar) for cplx in cplx_tqdm: cplx_tqdm.set_postfix(mol=os.path.basename(cplx)) self.logger.info(f'\nProcessing PDB file: {cplx}') # names of the molecule mol_name = os.path.splitext(os.path.basename(cplx))[0] mol_aug_name_list = [] try: ################################################ # get the pdbs of the conformation and its ref # for the original data (not augmetned one) ################################################ if verbose: self.logger.info( f'\nMolecule: {mol_name}.' f'\nStart generating top HDF5 group "{mol_name}"...' f'\n{"":4s}Reading PDB data into database...') # get the bare name of the molecule # and define the name of the native # i.e. 1AK4_100w -> 1AK4 bare_mol_name = mol_name.split('_')[0] ref_name = bare_mol_name + '.pdb' # check if we have a decoy or native # and find the reference if mol_name == bare_mol_name: ref = cplx else: if len(self.all_native) > 0: ref = list( filter(lambda x: ref_name in x, self.all_native)) if len(ref) == 0: raise ValueError('Native not found') else: if len(ref) > 1: warnings.warn( f'Multiple native reference found, here used {ref[0]}') ref = ref[0] if ref == '': ref = None else: ref = None # crete a subgroup for the molecule molgrp = self.f5.require_group(mol_name) molgrp.attrs['type'] = 'molecule' # add the ref and the complex self._add_pdb(molgrp, cplx, 'complex') if ref is not None: self._add_pdb(molgrp, ref, 'native') if verbose: self.logger.info( f'{"":4s}Generated subgroup "complex"' f' to store pdb data of the current model.') if ref: self.logger.info( f'{"":4s}Generated subgroup "native"' f' to store pdb data of the reference molecule.') ################################################ # add the features ################################################ feature_error_flag = False # when False: success; when True: failed if self.compute_features is not None: if verbose: self.logger.info( f'{"":4s}Calculating features...') molgrp.require_group('features') molgrp.require_group('features_raw') feature_error_flag = self._compute_features(self.compute_features, molgrp['complex'][( )], molgrp['features'], molgrp['features_raw'], self.chain1, self.chain2, self.logger) if feature_error_flag: self.feature_error += [mol_name] # ignore the targets/grid/augmentation computation # and directly go to next molecule. Remove errored # molecule later. # Otherwise, keep computing and report errored mol. if remove_error: continue if verbose: if not feature_error_flag or not remove_error: self.logger.info( f'\n{"":4s}Generated subgroup "features"' f' to store xyz-based feature values.' f'{"":4s}Generated subgroup "features_raw"' f' to store human read feature values') ################################################ # add the targets ################################################ if self.compute_targets is not None: if verbose: self.logger.info( f'{"":4s}Calculating targets...') molgrp.require_group('targets') self._compute_targets(self.compute_targets, molgrp['complex'][()], molgrp['targets']) if verbose: self.logger.info( f'{"":4s}Generated subgroup "targets" ' f'to store targets, such as BIN_CLASS, dockQ, etc.') ################################################ # add the box center ################################################ if verbose: self.logger.info( f'{"":4s}Calculating grid box center...') grid_error_flag = False molgrp.require_group('grid_points') try: center = self._get_grid_center( molgrp['complex'][()], contact_distance) molgrp['grid_points'].create_dataset( 'center', data=center) if verbose: self.logger.info( f'{"":4s}Generated subgroup "grid_points"' f' to store grid box center.') except ValueError as ex: grid_error_flag = True self.grid_error += [mol_name] self.logger.exception(ex) if remove_error: continue ################################################ # DATA AUGMENTATION ################################################ # GET ALL THE NAMES if self.data_augmentation is not None: mol_aug_name_list = [ mol_name + '_r%03d' % (idir + 1) for idir in range( self.data_augmentation)] else: mol_aug_name_list = [] if verbose and mol_aug_name_list: self.logger.info( f'{"":2s}Start augmenting data' f' with {self.data_augmentation} times...') # loop over the complexes for mol_aug_name in mol_aug_name_list: # crete a subgroup for the molecule molgrp = self.f5.require_group(mol_aug_name) molgrp.attrs['type'] = 'molecule' # copy the ref into it if ref is not None: self._add_pdb(molgrp, ref, 'native') # get the rotation axis and angle if self.align is None: axis, angle = pdb2sql.transform.get_rot_axis_angle( random_seed) else: axis, angle = self._get_aligned_rotation_axis_angle(random_seed, self.align) # create the new pdb and get molecule center # molecule center is the origin of rotation) mol_center = self._add_aug_pdb( molgrp, cplx, 'complex', axis, angle) # copy the targets/features if 'targets' in self.f5[mol_name]: self.f5.copy(mol_name + '/targets/', molgrp) self.f5.copy(mol_name + '/features/', molgrp) # rotate the feature self._rotate_feature( molgrp, axis, angle, mol_center) # grid center used to create grid box molgrp.require_group('grid_points') center = pdb2sql.transform.rot_xyz_around_axis( self.f5[mol_name + '/grid_points/center'], axis, angle, mol_center) molgrp['grid_points'].create_dataset( 'center', data=center) # store the rotation axis/angl/center as attriutes # in case we need them later molgrp.attrs['axis'] = axis molgrp.attrs['angle'] = angle molgrp.attrs['center'] = mol_center # cache aug mols if original mol has errored features if feature_error_flag: self.feature_error += mol_aug_name_list if grid_error_flag: self.grid_error += mol_aug_name_list if verbose and mol_aug_name_list: self.logger.info( f'{"":2s}Completed data augmentation' f' and generated top HDF5 groups, e.g. {mol_aug_name}.') ################################################ # Successul message ################################################ if verbose: self.logger.info( f'\nSuccessfully generated top HDF5 group "{mol_name}".\n') # all other errors except BaseException: raise ################################################## # Post processing ################################################## # Remove errored molecules errored_mol = list(set(self.feature_error + self.grid_error)) if errored_mol: if remove_error: for mol in errored_mol: del self.f5[mol] if self.feature_error: self.logger.info( f'Molecules with errored features are removed:' f'\n{self.feature_error}') if self.grid_error: self.logger.info( f'Molecules with errored grid points are removed:' f'\n{self.grid_error}') else: if self.feature_error: self.logger.warning( f'The following molecules have errored features:' f'\n{self.feature_error}') if self.grid_error: self.logger.warning( f'The following molecules have errored grid points:' f'\n{self.grid_error}') # close the file self.f5.close() self.logger.info( f'\n# Successfully created database: {self.hdf5}\n')
[docs] def aug_data(self, augmentation, keep_existing_aug=True, random_seed=None): """Augment exiting original PDB data and features. Args: augmentation(int): Times of augmentation keep_existing_aug (bool, optional): Keep existing augmentated data. If False, existing aug will be removed. Defaults to True. Examples: >>> database = DataGenerator(h5='database.h5') >>> database.aug_data(augmentation=3, append=True) >>> grid_info = { >>> 'number_of_points': [30,30,30], >>> 'resolution': [1.,1.,1.], >>> 'atomic_densities': {'C':1.7, 'N':1.55, 'O':1.52, 'S':1.8}, >>> } >>> database.map_features(grid_info) """ # check if file exists if not os.path.isfile(self.hdf5): raise FileNotFoundError( 'File %s does not exists' % self.hdf5) # get the folder names f5 = h5py.File(self.hdf5, 'a') fnames = f5.keys() # get the non rotated ones fnames_original = list( filter(lambda x: not re.search(r'_r\d+$', x), fnames)) # get the rotated ones fnames_augmented = list( filter(lambda x: re.search(r'_r\d+$', x), fnames)) aug_id_start = 0 if keep_existing_aug: exiting_augs = list( filter(lambda x: re.search(fnames_original[0] + r'_r\d+$', x), fnames_augmented)) aug_id_start += len(exiting_augs) else: for i in fnames_augmented: del f5[i] self.logger.info( f'{"":s}\n# Start augmenting data' f' with {augmentation} times...') # GET ALL THE NAMES for mol_name in fnames_original: mol_aug_name_list = [ mol_name + '_r%03d' % (idir + 1) for idir in range(aug_id_start, aug_id_start + augmentation)] # loop over the complexes for mol_aug_name in mol_aug_name_list: # crete a subgroup for the molecule molgrp = f5.require_group(mol_aug_name) molgrp.attrs['type'] = 'molecule' # copy the ref into it if 'native' in f5[mol_name]: f5.copy(mol_name + '/native', molgrp) # get the rotation axis and angle if self.align is None: axis, angle = pdb2sql.transform.get_rot_axis_angle( random_seed) else: axis, angle = self._get_aligned_rotation_axis_angle(random_seed, self.align) # create the new pdb and get molecule center # molecule center is the origin of rotation) mol_center = self._add_aug_pdb( molgrp, f5[mol_name + '/complex'][()], 'complex', axis, angle) # copy the targets/features if 'targets' in f5[mol_name]: f5.copy(mol_name + '/targets/', molgrp) f5.copy(mol_name + '/features/', molgrp) # rotate the feature self._rotate_feature(molgrp, axis, angle, mol_center) # grid center used to create grid box molgrp.require_group('grid_points') center = pdb2sql.transform.rot_xyz_around_axis( f5[mol_name + '/grid_points/center'], axis, angle, mol_center) molgrp['grid_points'].create_dataset( 'center', data=center) # store the rotation axis/angl/center as attriutes # in case we need them later molgrp.attrs['axis'] = axis molgrp.attrs['angle'] = angle molgrp.attrs['center'] = mol_center f5.close() self.logger.info( f'\n# Successfully augmented data in {self.hdf5}')
# ==================================================================================== # # ADD FEATURES TO AN EXISTING DATASET # # ====================================================================================
[docs] def add_feature(self, remove_error=True, prog_bar=True): """Add a feature to an existing hdf5 file. Args: remove_error (bool): remove errored molecule prog_bar (bool, optional): use tqdm Example: >>> h5file = '1ak4.hdf5' >>> >>> #init the data assembler >>> database = DataGenerator(compute_features = ['deeprank.features.ResidueDensity'], >>> hdf5=h5file) >>> >>> database.add_feature(remove_error=True, prog_bar=True) """ # check if file exists if not os.path.isfile(self.hdf5): raise FileNotFoundError( 'File %s does not exists' % self.hdf5) # get the folder names f5 = h5py.File(self.hdf5, 'a') fnames = f5.keys() # get the non rotated ones fnames_original = list( filter(lambda x: not re.search(r'_r\d+$', x), fnames)) # get the rotated ones fnames_augmented = list( filter(lambda x: re.search(r'_r\d+$', x), fnames)) # check feature_error if not self.feature_error: self.feature_error = [] # computes the features of the original desc = '{:25s}'.format('Add features') for cplx_name in tqdm( fnames_original, desc=desc, ncols=100, disable=not prog_bar): # molgrp molgrp = f5[cplx_name] error_flag = False if self.compute_features is not None: # the internal features molgrp.require_group('features') molgrp.require_group('features_raw') error_flag = self._compute_features(self.compute_features, molgrp['complex'][()], molgrp['features'], molgrp['features_raw'], self.chain1, self.chain2, self.logger) if error_flag: self.feature_error += [cplx_name] # copy the data from the original to the augmented for cplx_name in fnames_augmented: # group of the molecule aug_molgrp = f5[cplx_name] # get the source group mol_name = re.split(r'_r\d+', molgrp.name)[0] src_molgrp = f5[mol_name] # get the rotation parameters axis = aug_molgrp.attrs['axis'] angle = aug_molgrp.attrs['angle'] center = aug_molgrp.attrs['center'] # copy the features to the augmented for k in molgrp['features']: if k not in aug_molgrp['features']: # copy data = src_molgrp['features/' + k][()] aug_molgrp.require_group('features') aug_molgrp.create_dataset( "features/" + k, data=data) # rotate self._rotate_feature( aug_molgrp, axis, angle, center, feat_name=[k]) # find errored augmented molecules tmp_aug_error = [] for mol in self.feature_error: tmp_aug_error += list(filter(lambda x: mol in x, fnames_augmented)) self.feature_error += tmp_aug_error # Remove errored molecules if self.feature_error: if remove_error: for mol in self.feature_error: del f5[mol] self.logger.info( f'Molecules with errored features are removed:\n' f'{self.feature_error}') else: self.logger.warning( f"The following molecules has errored features:\n" f'{self.feature_error}') # close the file f5.close()
# ==================================================================================== # # ADD TARGETS TO AN EXISTING DATASET # # ====================================================================================
[docs] def add_unique_target(self, targdict): """Add identical targets for all the complexes in the datafile. This is usefull if you want to add the binary class of all the complexes created from decoys or natives Args: targdict (dict): Example: {'DOCKQ':1.0} >>> database = DataGenerator(hdf5='1ak4.hdf5') >>> database.add_unique_target({'DOCKQ':1.0}) """ # check if file exists if not os.path.isfile(self.hdf5): raise FileNotFoundError( 'File %s does not exists' % self.hdf5) f5 = h5py.File(self.hdf5, 'a') for mol in list(f5.keys()): targrp = f5[mol].require_group('targets') for name, value in targdict.items(): targrp.create_dataset(name, data=np.array([value])) f5.close()
[docs] def add_target(self, prog_bar=False): """Add a target to an existing hdf5 file. Args: prog_bar (bool, optional): Use tqdm Example: >>> h5file = '1ak4.hdf5' >>> >>> #init the data assembler >>> database = DataGenerator(compute_targets =['deeprank.targets.binary_class'], >>> hdf5=h5file) >>> >>> database.add_target(prog_bar=True) """ # check if file exists if not os.path.isfile(self.hdf5): raise FileNotFoundError( 'File %s does not exists' % self.hdf5) # name of the hdf5 file f5 = h5py.File(self.hdf5, 'a') # get the folder names fnames = f5.keys() # get the non rotated ones fnames_original = list( filter(lambda x: not re.search(r'_r\d+$', x), fnames)) fnames_augmented = list( filter(lambda x: re.search(r'_r\d+$', x), fnames)) # compute the targets of the original desc = '{:25s}'.format('Add targets') for cplx_name in tqdm(fnames_original, desc=desc, ncols=100, disable=not prog_bar): # group of the molecule molgrp = f5[cplx_name] # add the targets if self.compute_targets is not None: molgrp.require_group('targets') self._compute_targets(self.compute_targets, molgrp['complex'][()], molgrp['targets']) # copy the targets of the original to the rotated for cplx_name in fnames_augmented: # group of the molecule aug_molgrp = f5[cplx_name] # get the source group mol_name = re.split(r'_r\d+', molgrp.name)[0] src_molgrp = f5[mol_name] # copy the targets to the augmented for k in molgrp['targets']: if k not in aug_molgrp['targets']: data = src_molgrp['targets/' + k][()] aug_molgrp.require_group('targets') aug_molgrp.create_dataset( "targets/" + k, data=data) # close the file f5.close()
[docs] def realign_complexes(self, align, compute_features=None, pssm_source=None): """Align all the complexes already present in the HDF5. Arguments: align {dict} -- alignement dictionary (see __init__) Keyword Arguments: compute_features {list} -- list of features to be computed if None computes the features specified in the attrs['features'] of the file (if present) pssm_source {str} -- path of the pssm files. If None the source specfied in the attrs['pssm_source'] will be used (if present) (default: {None}) Raises: ValueError: If no PSSM detected Example: >>> database = DataGenerator(hdf5='1ak4.hdf5') >>> # if comute_features and pssm_source are not specified >>> # the values in hdf5.attrs['features'] and hdf5.attrs['pssm_source'] will be used >>> database.realign_complex(align={'axis':'x'}, >>> compute_features['deeprank.features.X'], >>> pssm_source='./1ak4_pssm/') """ f5 = h5py.File(self.hdf5, 'a') mol_names = f5.keys() self.logger.info( f'\n# Start aligning the HDF5 database: {self.hdf5}') # deal with the features if self.compute_features is None: if compute_features is None: if 'features' in f5.attrs: self.compute_features = list(f5.attrs['features']) else: self.compute_features = compute_features # deal with the pssm source if self.pssm_source is not None: config.PATH_PSSM_SOURCE = self.pssm_source elif pssm_source is not None: config.PATH_PSSM_SOURCE = pssm_source elif 'pssm_source' in f5.attrs: config.PATH_PSSM_SOURCE = f5.attrs['pssm_source'] else: raise ValueError('No pssm source detected') # loop over the complexes desc = '{:25s}'.format('Add features') for mol in tqdm(mol_names, desc=desc, ncols=100): # align the pdb molgrp = f5[mol] pdb = molgrp['complex'][()] sqldb = self._get_aligned_sqldb(pdb, align) data = sqldb.sql2pdb() data = np.array(data).astype('|S78') molgrp['complex'][...] = data # remove prexisting features old_dir = ['features', 'features_raw', 'mapped_features'] for od in old_dir: if od in molgrp: del molgrp[od] # the internal features molgrp.require_group('features') molgrp.require_group('features_raw') # compute features error_flag = self._compute_features(self.compute_features, molgrp['complex'][()], molgrp['features'], molgrp['features_raw'], self.chain1, self.chain2, self.logger) f5.close()
# ==================================================================================== # # PRECOMPUTE TEH GRID POINTS # # ====================================================================================
[docs] def _get_grid_center(self, pdb, contact_distance): sqldb = pdb2sql.interface(pdb) contact_atoms = sqldb.get_contact_atoms(cutoff=contact_distance, chain1=self.chain1, chain2=self.chain2) tmp = [] for i in contact_atoms.values(): tmp.extend(i) contact_atoms = list(set(tmp)) center_contact = np.mean( np.array(sqldb.get('x,y,z', rowID=contact_atoms)), 0) sqldb._close() return center_contact
[docs] def precompute_grid(self, grid_info, contact_distance=8.5, prog_bar=False, time=False, try_sparse=True): # name of the hdf5 file f5 = h5py.File(self.hdf5, 'a') # check all the input PDB files mol_names = f5.keys() # get the local progress bar desc = '{:25s}'.format('Precompute grid points') mol_tqdm = tqdm(mol_names, desc=desc, disable=not prog_bar) if not prog_bar: print(desc, ':', self.hdf5) sys.stdout.flush() # loop over the data files for mol in mol_tqdm: mol_tqdm.set_postfix(mol=mol) # compute the data we want on the grid gt.GridTools(molgrp=f5[mol], chain1=self.chain1, chain2=self.chain2, number_of_points=grid_info['number_of_points'], resolution=grid_info['resolution'], contact_distance=contact_distance, time=time, prog_bar=prog_bar, try_sparse=try_sparse) f5.close()
# ==================================================================================== # # MAP THE FEATURES TO THE GRID # # ====================================================================================
[docs] def map_features(self, grid_info={}, cuda=False, gpu_block=None, cuda_kernel='kernel_map.c', cuda_func_name='gaussian', try_sparse=True, reset=False, use_tmpdir=False, time=False, prog_bar=True, grid_prog_bar=False, remove_error=True): """Map the feature on a grid of points centered at the interface. If features to map are not given, they will be are automatically determined for each molecule. Otherwise, given features will be mapped for all molecules (i.e. existing mapped features will be recalculated). Args: grid_info (dict): Informaton for the grid. See deeprank.generate.GridTools.py for details. cuda (bool, optional): Use CUDA gpu_block (None, optional): GPU block size to be used cuda_kernel (str, optional): filename containing CUDA kernel cuda_func_name (str, optional): The name of the function in the kernel try_sparse (bool, optional): Try to save the grids as sparse format reset (bool, optional): remove grids if some are already present use_tmpdir (bool, optional): use a scratch directory time (bool, optional): time the mapping process prog_bar (bool, optional): use tqdm for each molecule grid_prog_bar (bool, optional): use tqdm for each grid remove_error (bool, optional): remove the data that errored Example: >>> #init the data assembler >>> database = DataGenerator(hdf5='1ak4.hdf5') >>> >>> # map the features >>> grid_info = { >>> 'number_of_points': [30,30,30], >>> 'resolution': [1.,1.,1.], >>> 'atomic_densities': {'C':1.7, 'N':1.55, 'O':1.52, 'S':1.8}, >>> } >>> >>> database.map_features(grid_info,try_sparse=True,time=False,prog_bar=True) """ # default CUDA cuda_func = None cuda_atomic = None # disable CUDA when using MPI if self.mpi_comm is not None: if self.mpi_comm.Get_size() > 1: if cuda: self.logger.warning( 'CUDA mapping disabled when using MPI') cuda = False # name of the hdf5 file f5 = h5py.File(self.hdf5, 'a') # check all the input PDB files mol_names = f5.keys() if len(mol_names) == 0: f5.close() raise ValueError(f'No molecules found in {self.hdf5}.') ################################################################ # Check grid_info ################################################################ # fills in the grid data if not provided: default = NONE grid_info_ref = copy.deepcopy(grid_info) grinfo = ['number_of_points', 'resolution'] for gr in grinfo: if gr not in grid_info: grid_info[gr] = None # by default we do not map atomic densities if 'atomic_densities' not in grid_info: grid_info['atomic_densities'] = None # fills in the features mode if somes are missing: default = IND modes = ['atomic_densities_mode', 'feature_mode'] for m in modes: if m not in grid_info: grid_info[m] = 'ind' ################################################################ # ################################################################ # sanity check for cuda if cuda and gpu_block is None: # pragma: no cover self.logger.info( f'GPU block automatically set to 8 x 8 x 8. ' f'You can set block size with gpu_block=[n,m,k]') gpu_block = [8, 8, 8] # initialize cuda if cuda: # pragma: no cover # compile cuda module npts = grid_info['number_of_points'] res = grid_info['resolution'] module = self._compile_cuda_kernel(cuda_kernel, npts, res) # get the cuda function for the atomic/residue feature cuda_func = self._get_cuda_function( module, cuda_func_name) # get the cuda function for the atomic densties cuda_atomic_name = 'atomic_densities' cuda_atomic = self._get_cuda_function( module, cuda_atomic_name) # get the local progress bar desc = '{:25s}'.format('Map Features') mol_tqdm = tqdm(mol_names, desc=desc, disable=not prog_bar) if not prog_bar: self.logger.info(f'{desc}: {self.hdf5}') # loop over the data files for mol in mol_tqdm: mol_tqdm.set_postfix(mol=mol) # Determine which feature to map # if feature not given, then determine it for each molecule if 'feature' not in grid_info_ref: # if we havent mapped anything yet or if we reset if 'mapped_features' not in list(f5[mol].keys()) or reset: grid_info['feature'] = list( f5[mol + '/features'].keys()) # if we have already mapped stuff elif 'mapped_features' in list(f5[mol].keys()): # feature name all_feat = list(f5[mol + '/features'].keys()) # feature already mapped mapped_feat = list( f5[mol + '/mapped_features/Feature_ind'].keys()) # we select only the feture that were not mapped yet grid_info['feature'] = [] for feat_name in all_feat: if not any(map(lambda x: x.startswith(feat_name + '_'), mapped_feat)): grid_info['feature'].append(feat_name) try: # compute the data we want on the grid gt.GridTools( molgrp=f5[mol], chain1=self.chain1, chain2=self.chain2, number_of_points=grid_info['number_of_points'], resolution=grid_info['resolution'], atomic_densities=grid_info['atomic_densities'], atomic_densities_mode=grid_info['atomic_densities_mode'], feature=grid_info['feature'], feature_mode=grid_info['feature_mode'], cuda=cuda, gpu_block=gpu_block, cuda_func=cuda_func, cuda_atomic=cuda_atomic, time=time, prog_bar=grid_prog_bar, try_sparse=try_sparse) except BaseException: self.map_error.append(mol) self.logger.exception( f'Error during the mapping of {mol}') # remove the molecule with issues if self.map_error: if remove_error: for mol in self.map_error: del f5[mol] self.logger.warning( f"Molecules with errored feature mapping are removed:\n" f"{self.map_error}") else: self.logger.warning( f"The following moleclues have errored feature mapping:\n" f"{self.map_error}") # close he hdf5 file f5.close()
# ==================================================================================== # # REMOVE DATA FROM THE DATA SET # # ====================================================================================
[docs] def remove(self, feature=True, pdb=True, points=True, grid=False): """Remove data from the data set. Equivalent to the cleandata command line tool. Once the data has been removed from the file it is impossible to add new features/targets Args: feature (bool, optional): Remove the features pdb (bool, optional): Remove the pdbs points (bool, optional): remove teh grid points grid (bool, optional): remove the maps """ self.logger.debug('Remove features') # name of the hdf5 file f5 = h5py.File(self.hdf5, 'a') # get the folder names mol_names = f5.keys() for name in mol_names: mol_grp = f5[name] if feature and 'features' in mol_grp: del mol_grp['features'] del mol_grp['features_raw'] if pdb and 'complex' in mol_grp and 'native' in mol_grp: del mol_grp['complex'] del mol_grp['native'] if points and 'grid_points' in mol_grp: del mol_grp['grid_points'] if grid and 'mapped_features' in mol_grp: del mol_grp['mapped_features'] f5.close() # reclaim the space os.system('h5repack %s _tmp.h5py' % self.hdf5) os.system('mv _tmp.h5py %s' % self.hdf5)
# ==================================================================================== # # Simply tune or test the kernel # # ====================================================================================
[docs] def _tune_cuda_kernel(self, grid_info, cuda_kernel='kernel_map.c', func='gaussian'): # pragma: no cover """Tune the CUDA kernel using the kernel tuner http://benvanwerkhoven.github.io/kernel_tuner/ Args: grid_info (dict): information for the grid definition cuda_kernel (str, optional): file containing the kernel func (str, optional): function in the kernel to be used Raises: ValueError: If the tuner has not been used """ try: from kernel_tuner import tune_kernel except BaseException: print( 'Install the Kernel Tuner: \n \t\t pip install kernel_tuner') print('http://benvanwerkhoven.github.io/kernel_tuner/') # fills in the grid data if not provided: default = NONE grinfo = ['number_of_points', 'resolution'] for gr in grinfo: if gr not in grid_info: raise ValueError( '%s must be specified to tune the kernel') # define the grid nx, ny, nz = grid_info['number_of_points'] dx, dy, dz = grid_info['resolution'] lx, ly, lz = nx * dx, ny * dy, nz * dz x = np.linspace(0, lx, nx) y = np.linspace(0, ly, ny) z = np.linspace(0, lz, nz) # create the dictionary containing the tune parameters tune_params = OrderedDict() tune_params['block_size_x'] = [2, 4, 8, 16, 32] tune_params['block_size_y'] = [2, 4, 8, 16, 32] tune_params['block_size_z'] = [2, 4, 8, 16, 32] # define the final grid grid = np.zeros(grid_info['number_of_points']) # arguments of the CUDA function x0, y0, z0 = np.float32(0), np.float32(0), np.float32(0) alpha = np.float32(0) args = [alpha, x0, y0, z0, x, y, z, grid] # dimensionality problem_size = grid_info['number_of_points'] # get the kernel kernel = os.path.dirname( os.path.abspath(__file__)) + '/' + cuda_kernel kernel_code_template = open(kernel, 'r').read() npts = grid_info['number_of_points'] res = grid_info['resolution'] kernel_code = kernel_code_template % { 'nx': npts[0], 'ny': npts[1], 'nz': npts[2], 'RES': np.max(res)} tunable_kernel = self._tunable_kernel(kernel_code) # tune tune_kernel(func, tunable_kernel, problem_size, args, tune_params)
# ==================================================================================== # # Simply test the kernel # # ====================================================================================
[docs] def _test_cuda(self, grid_info, gpu_block=8, cuda_kernel='kernel_map.c', func='gaussian'): # pragma: no cover """Test the CUDA kernel. Args: grid_info (dict): Information for the grid definition gpu_block (int, optional): GPU block size to be used cuda_kernel (str, optional): File containing the kernel func (str, optional): function in the kernel to be used Raises: ValueError: If the kernel has not been installed """ from time import time # fills in the grid data if not provided: default = NONE grinfo = ['number_of_points', 'resolution'] for gr in grinfo: if gr not in grid_info: raise ValueError( '%s must be specified to tune the kernel') # get the cuda function npts = grid_info['number_of_points'] res = grid_info['resolution'] module = self._compile_cuda_kernel(cuda_kernel, npts, res) cuda_func = self._get_cuda_function(module, func) # define the grid nx, ny, nz = grid_info['number_of_points'] dx, dy, dz = grid_info['resolution'] lx, ly, lz = nx * dx, ny * dy, nz * dz # create the coordinate x = np.linspace(0, lx, nx) y = np.linspace(0, ly, ny) z = np.linspace(0, lz, nz) # book memp on the gpu x_gpu = gpuarray.to_gpu(x.astype(np.float32)) y_gpu = gpuarray.to_gpu(y.astype(np.float32)) z_gpu = gpuarray.to_gpu(z.astype(np.float32)) grid_gpu = gpuarray.zeros( grid_info['number_of_points'], np.float32) # make sure we have three block value if not isinstance(gpu_block, list): gpu_block = [gpu_block] * 3 # get the grid gpu_grid = [int(np.ceil(n / b)) for b, n in zip(gpu_block, grid_info['number_of_points'])] print('GPU BLOCK:', gpu_block) print('GPU GRID :', gpu_grid) xyz_center = np.random.rand(500, 3).astype(np.float32) alpha = np.float32(1) t0 = time() for xyz in xyz_center: x0, y0, z0 = xyz cuda_func(alpha, x0, y0, z0, x_gpu, y_gpu, z_gpu, grid_gpu, block=tuple(gpu_block), grid=tuple(gpu_grid)) print('Done in: %f ms' % ((time() - t0) * 1000))
# ==================================================================================== # # Routines needed to handle CUDA # # ====================================================================================
[docs] @staticmethod def _compile_cuda_kernel(cuda_kernel, npts, res): # pragma: no cover """Compile the cuda kernel. Args: cuda_kernel (str): filename npts (tuple(int)): number of grid points in each direction res (tuple(float)): resolution in each direction Returns: compiler.SourceModule: compiled kernel """ # get the cuda kernel path kernel = os.path.dirname( os.path.abspath(__file__)) + '/' + cuda_kernel kernel_code_template = open(kernel, 'r').read() kernel_code = kernel_code_template % { 'nx': npts[0], 'ny': npts[1], 'nz': npts[2], 'RES': np.max(res)} # compile the kernel mod = compiler.SourceModule(kernel_code) return mod
[docs] @staticmethod def _get_cuda_function(module, func_name): # pragma: no cover """Get a single function from the compiled kernel. Args: module (compiler.SourceModule): compiled kernel module func_name (str): Name of the funtion Returns: func: cuda function """ cuda_func = module.get_function(func_name) return cuda_func
# tranform the kernel to a tunable one
[docs] @staticmethod def _tunable_kernel(kernel): # pragma: no cover """Make a tunale kernel. Args: kernel (str): String of the kernel Returns: TYPE: tunable kernel """ switch_name = { 'blockDim.x': 'block_size_x', 'blockDim.y': 'block_size_y', 'blockDim.z': 'block_size_z'} for old, new in switch_name.items(): kernel = kernel.replace(old, new) return kernel
# ==================================================================================== # # FILTER DATASET # # ===================================================================================
[docs] def _filter_cplx(self): """Filter the name of the complexes.""" # read the class ID with open(self.pdb_select) as f: pdb_name = f.readlines() pdb_name = [name.split()[0] + '.pdb' for name in pdb_name] # create the filters tmp_path = [] for name in pdb_name: tmp_path += list(filter(lambda x: name in x, self.pdb_path)) # update the pdb_path self.pdb_path = tmp_path
# ==================================================================================== # # FEATURES ROUTINES # # ====================================================================================
[docs] @staticmethod def _compute_features(feat_list, pdb_data, featgrp, featgrp_raw, chain1, chain2, logger): """Compute the features. Args: feat_list (list(str)): list of function name, e.g., ['deeprank.features.ResidueDensity', 'deeprank.features.PSSM_IC'] pdb_data (bytes): PDB translated in bytes featgrp (str): name of the group where to store the xyz feature featgrp_raw (str): name of the group where to store the raw feature chain1 (str): First chain ID chain2 (str): Second chain ID logger (logger): name of logger object Return: bool: error happened or not """ error_flag = False # when False: success; when True: failed for feat in feat_list: try: feat_module = importlib.import_module(feat, package=None) feat_module.__compute_feature__(pdb_data, featgrp, featgrp_raw, chain1, chain2) except Exception as ex: logger.exception(ex) error_flag = True return error_flag
# ==================================================================================== # # TARGETS ROUTINES # # ====================================================================================
[docs] @staticmethod def _compute_targets(targ_list, pdb_data, targrp): """Compute the targets. Args: targ_list (list(str)): list of function name pdb_data (bytes): PDB translated in btes targrp (str): name of the group where to store the targets logger (logger): name of logger object """ for targ in targ_list: targ_module = importlib.import_module(targ, package=None) targ_module.__compute_target__(pdb_data, targrp)
# ==================================================================================== # # ADD PDB FILE # # ====================================================================================
[docs] def _add_pdb(self, molgrp, pdbfile, name): """Add a pdb to a molgrp. Args: molgrp (str): mopl group where tp add the pdb pdbfile (str): psb file to add name (str): dataset name in the hdf5 molgroup """ # no alignement if self.align is None: # read the pdb and extract the ATOM lines with open(pdbfile, 'r') as fi: data = [line.split('\n')[0] for line in fi if line.startswith('ATOM')] # some alignement elif isinstance(self.align, dict): sqldb = self._get_aligned_sqldb(pdbfile, self.align) data = sqldb.sql2pdb() # PDB default line length is 80 # http://www.wwpdb.org/documentation/file-format data = np.array(data).astype('|S78') molgrp.create_dataset(name, data=data)
# @staticmethod
[docs] def _get_aligned_sqldb(self, pdbfile, dict_align): """return a sqldb of the pdb that is aligned as specified in the dict Arguments: pdbfile {str} -- path ot the pdb dict_align {dict} -- dictionanry of options to align the pdb """ if 'selection' not in dict_align.keys(): dict_align['selection'] = {} if 'export' not in dict_align.keys(): dict_align['export'] = False if dict_align['selection'] == 'interface': sqldb = align_interface(pdbfile, plane=dict_align['plane'], export=dict_align['export'], chain1=self.chain1, chain2=self.chain2) else: sqldb = align_along_axis(pdbfile, axis=dict_align['axis'], export=dict_align['export'], **dict_align['selection']) return sqldb
# ==================================================================================== # # AUGMENTED DATA # # ====================================================================================
[docs] @staticmethod def _get_aligned_rotation_axis_angle(random_seed, dict_align): """Returns the axis and angle of rotation for data augmentation with aligned complexes Arguments: random_seed {int} -- random seed of rotation dict_align {dict} -- the dict describing the alignement Returns: list(float): axis of rotation float: angle of rotation """ if random_seed is not None: np.random.seed(random_seed) angle = 2 * np.pi * np.random.rand() if 'plane' in dict_align.keys(): if dict_align['plane'] == 'xy': axis = [0., 0., 1.] elif dict_align['plane'] == 'xz': axis = [0., 1., 0.] elif dict_align['plane'] == 'yz': axis = [1., 0., 0.] else: raise ValueError("plane must be xy, xz or yz") elif 'axis' in dict_align.keys(): if dict_align['axis'] == 'x': axis = [1., 0., 0.] elif dict_align['axis'] == 'y': axis = [0., 1., 0.] elif dict_align['axis'] == 'z': axis = [0., 0., 1.] else: raise ValueError("axis must be x, y or z") else: raise ValueError('dict_align must contains plane or axis') return axis, angle
# add a rotated pdb structure to the database
[docs] def _add_aug_pdb(self, molgrp, pdbfile, name, axis, angle): """Add augmented pdbs to the dataset. Args: molgrp (str): name of the molgroup pdbfile (str): pdb file name name (str): name of the dataset axis (list(float)): axis of rotation angle (float): angle of rotation dict_align (dict): dict for alignement of the original pdb Returns: list(float): center of the molecule """ # create the sqldb and extract positions if self.align is None: sqldb = pdb2sql.pdb2sql(pdbfile) else: sqldb = self._get_aligned_sqldb(pdbfile, self.align) # rotate the positions pdb2sql.transform.rot_axis(sqldb, axis, angle) # get molecule center xyz = sqldb.get('x,y,z') center = np.mean(xyz, 0) # get the pdb-format data data = sqldb.sql2pdb() data = np.array(data).astype('|S78') molgrp.create_dataset(name, data=data) # close the db sqldb._close() return center
# rotate th xyz-formatted feature in the database
[docs] @staticmethod def _rotate_feature(molgrp, axis, angle, center, feat_name='all'): """Rotate the raw feature values. Args: molgrp (str): name pf the molgrp axis (list(float)): axis of rotation angle (float): angle of rotation center (list(float)): center of rotation feat_name (str): name of the feature to rotate or 'all' """ if feat_name == 'all': feat = list(molgrp['features'].keys()) else: feat = feat_name if not isinstance(feat, list): feat = list(feat) for fn in feat: # extract the data data = molgrp['features/' + fn][()] # if data not empty if data.shape[0] != 0: # xyz xyz = data[:, 1:4] # get rotated xyz xyz_rot = pdb2sql.transform.rot_xyz_around_axis( xyz, axis, angle, center) # put back the data molgrp['features/' + fn][:, 1:4] = xyz_rot