src/common/rotamer.cc

00001 
00002 #include <stdio.h>
00003 #include <stdlib.h>
00004 #include <string.h>
00005 #include <math.h>
00006 #include <assert.h>
00007 #include <string>
00008 
00009 #include "open_prospect.h"
00010 #include "tree_decomp.h"
00011 #include "sm_matrix.h"
00012 
00013 
00014 
00015 int ResidueAtoms::GetAtomIndex( int atom_name_in ) {
00016         int cur_atom = -1;
00017         for ( int i = 0; i < atom_count; i++ ) {
00018                 if ( atom_name[i] == atom_name_in ) 
00019                         cur_atom = i;
00020         }       
00021         return cur_atom;        
00022 }
00023 
00024 void ResidueAtoms::SetAtomLoc( int atom_name_in, vec3 loc) {
00025         
00026         int cur_atom = -1;
00027         for ( int i = 0; i < atom_count; i++ ) {
00028                 if ( atom_name[i] == atom_name_in ) 
00029                         cur_atom = i;
00030         }       
00031         if ( cur_atom == -1 ) { 
00032                 cur_atom = atom_count;
00033                 atom_count++;
00034                 if ( atom_name_in >= 0 && atom_name_in < 4 ) {
00035                         bb_num[ atom_name_in ] = cur_atom;
00036                 }
00037                 atom_name[ cur_atom ] = atom_name_in;
00038         }
00039         
00040         atom[ cur_atom ][ 0 ] = loc[0];
00041         atom[ cur_atom ][ 1 ] = loc[1];
00042         atom[ cur_atom ][ 2 ] = loc[2];
00043 }
00044 
00045 void ResidueAtoms::SetTempFactor( int atom_index, float new_temp) {
00046         if (atom_index < atom_count) {
00047                 temp_factor[ atom_index ] = new_temp;
00048         }
00049 }
00050 
00051 
00052 void ResidueAtoms::ReadAtomLine( char * line ) {
00053         char tmp_x_str[30], tmp_y_str[30], tmp_z_str[30];
00054         strcpy(tmp_x_str, "");
00055         strcpy(tmp_y_str, "");
00056         strcpy(tmp_z_str, "");
00057         
00058         if ( !strncmp(line, "ATOM", 4) || !strncmp(line, "HETATM", 6) ) {                       
00059                 char tmp_str[50];
00060                 strncpy( tmp_str, &(line[22]), 5 );
00061                 tmp_str[5] = 0;
00062                 
00063                 if ( atom_count < kMaxResAtoms) {
00064                         //load in the atom num
00065                         strncpy( tmp_str, &(line[6]), 4 );
00066                         atom_num[ atom_count ] = atoi( tmp_str );
00067                         
00068                         strncpy(tmp_x_str, &(line[30]), 8);
00069                         tmp_x_str[8] = 0;
00070                         strncpy(tmp_y_str, &(line[38]), 8);
00071                         tmp_y_str[8] = 0;
00072                         strncpy(tmp_z_str, &(line[46]), 8);
00073                         tmp_z_str[8] = 0;
00074                         //location information/////////////////////////////
00075                         atom[atom_count][0] = atof(tmp_x_str);
00076                         atom[atom_count][1] = atof(tmp_y_str);
00077                         atom[atom_count][2] = atof(tmp_z_str);
00078                         
00079                         //Temperature factor/////////////////////////////
00080                         strncpy( tmp_str, &(line[60]), 6);
00081                         tmp_str[6] = 0;
00082                         temp_factor[ atom_count ] = atof( tmp_str );
00083                         
00084                         //residue id//////////////////////////////////////
00085                         strncpy(tmp_str, &(line[22]), 4 );
00086                         tmp_str[4] = 0;
00087                         res_num = atoi( tmp_str );
00088                         res_mod = line[26];
00089                         
00090                         //the_data->atom[i].residue[5] = 0;
00091                         //clean_space( the_data->atom[i].residue );
00092                         
00093                         //atom name//////////////////////////////////////
00094                         int atom_namenum = PDBAtomtoNum( &(line[13]) );
00095                         if ( atom_namenum >= 0 ) {
00096                                 atom_name[ atom_count ] = atom_namenum;
00097                                 if ( atom_namenum == anN || atom_namenum == anCa || 
00098                                          atom_namenum == anC || atom_namenum == anO  ) {
00099                                         bb_num[ atom_namenum ] = atom_count;
00100                                 }
00101                                 atom_count++;
00102                         }
00103                 }
00104         }       
00105 }
00106 
00107 
00108 
00109 
00110 char *ResidueAtoms::WriteAtomLine(int i, char RS, int res_num, char chain) {
00111         if ( i > atom_count )
00112                 return NULL;
00113         char *buffer = (char *)malloc( 200 );
00114         char tmp[10];
00115         AA1toAA3(RS, tmp);                      
00116         if (chain == 0)
00117                 chain = ' ';
00118         sprintf(buffer, "ATOM  %5d  %-3s %+3s %c%5d   %8.3f%8.3f%8.3f  1.00%6.2f\n", 
00119                         atom_num[i], 
00120                         pdb_atomname[ atom_name[i] ], 
00121                         tmp, 
00122                         chain, 
00123                         res_num,
00124                         atom[i][0], 
00125                         atom[i][1], 
00126                         atom[i][2],
00127                         temp_factor[i] );
00128         return buffer;
00129 }
00130 
00131 
00132 
00133 
00134 void smTranslateCoordinates(vec3 pre_anchor[3], vec3 cur_anchor[3], vec3 A[], int N) {
00135         
00136         int i;
00137         vec3 pre_position, cur_position;
00138         float matrix1[3][3], matrix2[3][3], matrix3[3][3], matrix4[3][3];
00139         
00140         smCopy(pre_anchor[anCa], pre_position);
00141         smCopy(cur_anchor[anCa], cur_position);
00142         smMinus(cur_position);
00143         
00144         //      smComputeConvertMatrix(pre_anchor[0], pre_anchor[1], pre_anchor[2], matrix1);
00145         //      smComputeConvertMatrix(cur_anchor[0], cur_anchor[1], cur_anchor[2], matrix2);
00146         
00147         smComputeConvertMatrix(pre_anchor[anCa], pre_anchor[anN], pre_anchor[anC], matrix1);
00148         smComputeConvertMatrix(cur_anchor[anCa], cur_anchor[anN], cur_anchor[anC], matrix2);
00149         
00150         
00151         smInverseMatrix(matrix2, matrix3);
00152         smMultMat(matrix3, matrix1, matrix4);
00153         //      smMultiMatrix(matrix2, matrix1, matrix4, 3);
00154         
00155         for (i=0; i<N; i++)
00156                 smAdd(cur_position, A[i]);
00157         for (i=0; i<N; i++)
00158                 smMultVec(A[i], matrix4);
00159         for (i=0; i<N; i++)
00160                 smAdd(pre_position, A[i]);
00161         return;
00162 }
00163 
00164 
00165 int Prospect_Target_Rotamer(ProspectParam *param, ResidueAtoms *structure, char residue, int rotamer_number) {
00166         //FIXME: need to check this code for compatibility with new residue structure format    
00167         //first step, copy in residue  lib co-ordinates
00168         vec3 a[20];
00169         int res_num = AA1toAANum(residue);
00170         if (res_num >= 0 && res_num < 20) {
00171                 ResidueInfoElement *res_info = &param->residue_info[ res_num ];                 
00172                 //copy atoms not part of the backbone
00173                 int atom_count = 0;
00174                 for (int i = 0; i < res_info->structure.atom_count; i++) {
00175                         if ( res_info->structure.atom_name[i] > anBBcount ) {
00176                                 a[ atom_count ][0] = res_info->structure.atom[i][0];
00177                                 a[ atom_count ][1] = res_info->structure.atom[i][1];
00178                                 a[ atom_count ][2] = res_info->structure.atom[i][2];                    
00179                                 atom_count++;
00180                         }
00181                 }               
00182                 assert( atom_count < 20);
00183                 
00184                 //apply rotamer angle adjustments               
00185                 
00186                 //translate the rotamer to the backbone
00187                 smTranslateCoordinates( structure->atom, res_info->structure.atom, a, atom_count);              
00188                 for (int i = 0; i < atom_count; i++) {
00189                         structure->atom[i ][0] = a[i][0]; //FIXME: need offset here
00190                         structure->atom[i ][1] = a[i][1];
00191                         structure->atom[i ][2] = a[i][2];
00192                 }               
00193                 if ( residue != 'G' && residue != 'P' ) {
00194                         RotamerElement *rotamer = &param->rotamer_array[ res_num ].rotamer[ rotamer_number ];
00195                         //      assert( res_info->torsion_count == param->rotamer_array[ res_num ].chi_count );
00196                         for (int i = 0; i < res_info->torsion_count; i++) {
00197                                 //                      for (int i = 0; i < res_info->torsion_count && i < 1; i++) 
00198                                 int tor_a = res_info->torsion_a[i];
00199                                 int tor_b = res_info->torsion_b[i];
00200                                 
00201                                 vec3 axis;
00202                                 smSub(structure->atom[tor_b], structure->atom[tor_a], axis);
00203                                 smNormalize(axis);
00204                                 float rot_mat[3][3];
00205                                 smMakeRotationMatrix( (rotamer->chi[i] - res_info->torsion_val[i])/180.0*M_PI, axis, rot_mat);
00206                                 assert( !isnan(rot_mat[0][0]) );
00207                                 //      smMakeRotationMatrix(  res_info->torsion_val[i] - rotamer->chi[i], axis, rot_mat);
00208                                 /* Do rotation */
00209                                 for(int j=0; j < structure->atom_count; j++) {
00210                                         if ( res_info->torsion_mask[i][j] ) {
00211                                                 vec3 tmp;
00212                                                 tmp[0] = structure->atom[j][0] - structure->atom[tor_a][0];
00213                                                 tmp[1] = structure->atom[j][1] - structure->atom[tor_a][1];
00214                                                 tmp[2] = structure->atom[j][2] - structure->atom[tor_a][2];
00215                                                 structure->atom[j][0] = rot_mat[0][0]*tmp[0] + rot_mat[0][1]*tmp[1] + rot_mat[0][2]*tmp[2] + structure->atom[tor_a][0];
00216                                                 structure->atom[j][1] = rot_mat[1][0]*tmp[0] + rot_mat[1][1]*tmp[1] + rot_mat[1][2]*tmp[2] + structure->atom[tor_a][1];
00217                                                 structure->atom[j][2] = rot_mat[2][0]*tmp[0] + rot_mat[2][1]*tmp[1] + rot_mat[2][2]*tmp[2] + structure->atom[tor_a][2];
00218                                                 assert( !isnan(structure->atom[j][0]) && !isnan(structure->atom[j][1]) && !isnan(structure->atom[j][2]));
00219                                         }
00220                                 }
00221                         }
00222                 }               
00223         }
00224         return 0;
00225 }
00226 
00227 
00228 pValType calc_single_energy(TargetStruct *target_data, ResidueAtoms *res_struct) {
00229         pValType energy = 0.0;
00230         for (int atom_a = 0;
00231                  atom_a < res_struct->atom_count; atom_a++) {
00232                 if ( res_struct->atom_name[ atom_a ] >= anBBcount ) {
00233                         for (int res_b = 0; res_b < target_data->len; res_b++) {
00234                                 for (int atom_b = 0; atom_b < target_data->structure[res_b].atom_count; atom_b++) {
00235                                         if ( res_struct->atom_name[ atom_b ] < anBBcount ) {
00236                                                 energy += calc_LRET( res_struct->atom[atom_a], 
00237                                                                                          target_data->structure[ res_b ].atom[atom_b],
00238                                                                                          res_struct->radius[atom_a], 
00239                                                                                          target_data->structure[ res_b ].radius[atom_b] );
00240                                         }
00241                                 }
00242                         }
00243                 }
00244         }
00245         return energy;
00246 }
00247 
00248 
00249 
00250 pValType calc_pair_energy(ResidueAtoms *res_a, ResidueAtoms *res_b) {
00251         pValType energy = 0.0;
00252         for (int atom_a = 0;
00253                  atom_a < res_a->atom_count; atom_a++) {
00254                 if (res_a->atom_name[ atom_a ] >= anBBcount) {
00255                         for (int atom_b = 0;
00256                                  atom_b < res_b->atom_count; atom_b++) {
00257                                 if (res_a->atom_name[ atom_b ] >= anBBcount) {
00258                                         energy += calc_LRET( res_a->atom[atom_a], 
00259                                                                                  res_b->atom[atom_b], 
00260                                                                                  res_a->radius[atom_a], 
00261                                                                                  res_b->radius[atom_b] );
00262                                 }
00263                         }
00264                 }
00265         }
00266         return energy;
00267 }
00268 
00269 
00270 
00271 //#define target_hash( x, y )  (( x * target_data->len + y))
00272 
00273 
00274 
00275 
00276 int rot_mask_active(RotamerMask *mask, int i) {
00277         int count = 0;
00278         for (int j = 0; j < mask->count[i]; j++) {
00279                 if ( mask->state[i][j] ) 
00280                         count++;
00281         }
00282         return count;
00283 }
00284 
00285 
00286 
00287 
00288 int Prospect_Target_RotamerMask_Init(ProspectParam *param, TargetStruct *target_data, RotamerMask *rotamer_mask) {
00289         int total_count = 0;
00290         
00291         rotamer_mask->count = (long *)malloc(sizeof(long) * target_data->len);
00292         for (int i = 0; i < target_data->len; i++) {
00293                 rotamer_mask->count[i] = 0;
00294         }
00295         if (  target_data->structure == NULL ) 
00296                 return 1;
00297         
00298         for (int i = 0; i < target_data->len; i++) {
00299                 if ( target_data->structure[i].atom_count >= 4 ) {
00300                         char res_num = AA1toAANum(target_data->residue[i].RS) ;
00301                         for (int j = 0; j < param->rotamer_array[ res_num ].rotamer_count; j++) {
00302                                 if ( fabs(target_data->residue[i].psi - param->rotamer_array[ res_num ].rotamer[j].psi) < param->rotamer_array[res_num].psi_bin &&
00303                                          fabs(target_data->residue[i].phi - param->rotamer_array[ res_num ].rotamer[j].phi) < param->rotamer_array[res_num].phi_bin ) {
00304                                         rotamer_mask->count[i]++;
00305                                         total_count++;
00306                                 }
00307                         }
00308                 } else {
00309                         rotamer_mask->count[i] = 0;             
00310                 }
00311         }
00312         rotamer_mask->len = target_data->len;
00313         rotamer_mask->state = (char **)malloc(sizeof(char *) * target_data->len);
00314         rotamer_mask->state[0] = (char *)malloc(sizeof(char) * total_count);
00315         rotamer_mask->rot_num = (long **)malloc(sizeof(long *) * target_data->len);
00316         rotamer_mask->rot_num[0] = (long *)malloc(sizeof(long) * total_count);
00317         rotamer_mask->structure = (ResidueAtoms **)malloc(sizeof(ResidueAtoms *) * target_data->len);
00318         rotamer_mask->structure[0] = (ResidueAtoms *)malloc(sizeof(ResidueAtoms) * total_count);
00319         rotamer_mask->single_score = (pValType **)malloc(sizeof(pValType *) * target_data->len);
00320         rotamer_mask->single_score[0] = (pValType *)malloc(sizeof(pValType) * total_count);
00321         
00322         int cur_pos = 0;
00323         for (int i = 0; i < target_data->len; i++) {
00324                 if ( target_data->structure[i].atom_count >= 4) {
00325                         rotamer_mask->state[i] = rotamer_mask->state[0] + cur_pos;
00326                         rotamer_mask->structure[i] = rotamer_mask->structure[0] + cur_pos;
00327                         rotamer_mask->single_score[i] = rotamer_mask->single_score[0] + cur_pos;
00328                         rotamer_mask->rot_num[i] = rotamer_mask->rot_num[0] + cur_pos;
00329                 }
00330                 cur_pos += rotamer_mask->count[i];
00331         }
00332         
00333         for (int i = 0; i < target_data->len; i++) {
00334                 if  ( target_data->structure[i].atom_count >= 4) {
00335                         
00336                         char res_num = AA1toAANum(target_data->residue[i].RS) ;
00337                         int rot_count = 0;
00338                         for (int j = 0; j < param->rotamer_array[ res_num ].rotamer_count; j++) {
00339                                 if ( fabs(target_data->residue[i].psi - param->rotamer_array[ res_num ].rotamer[j].psi) < param->rotamer_array[res_num].psi_bin &&
00340                                          fabs(target_data->residue[i].phi - param->rotamer_array[ res_num ].rotamer[j].phi) < param->rotamer_array[res_num].phi_bin ) {
00341                                         rotamer_mask->rot_num[i][ rot_count ] = j;
00342                                         rot_count++;
00343                                 }
00344                         }
00345                         for (int j =0; j < rotamer_mask->count[i]; j++) {
00346                                 rotamer_mask->state[i][j] = 1;
00347                                 rotamer_mask->structure[i][j] = target_data->structure[i];
00348                                 Prospect_Target_Rotamer(param, &rotamer_mask->structure[i][j], target_data->residue[i].RS, rotamer_mask->rot_num[i][j]);
00349                                 rotamer_mask->single_score[i][j] = calc_single_energy(target_data, &rotamer_mask->structure[i][j]);
00350                         }
00351                 }
00352         }
00353         
00354         
00355         
00356         rotamer_mask->pair_score = (pValType ***)malloc(sizeof(pValType **) * target_data->len * target_data->len);
00357         for (int i = 0; i < target_data->len; i++) {
00358                 if  ( target_data->structure[i].atom_count >= 4)  {
00359                         for (int j = 0; j < target_data->len; j++) {
00360                                 if  ( j > i ) {
00361                                         if (  ( target_data->structure[i].atom_count >= 4)  && i != j) {
00362                                                 //int offset =  target_hash( i, j);
00363                                                 int offset = i * target_data->len + j;
00364                                                 rotamer_mask->pair_score[offset] = (pValType **)malloc(sizeof(pValType *) * rotamer_mask->count[i]  );
00365                                                 rotamer_mask->pair_score[offset][0] = (pValType *)malloc(sizeof(pValType) * rotamer_mask->count[i] * rotamer_mask->count[j]);
00366                                                 for (int  k = 0; k < rotamer_mask->count[i]; k++ ) {
00367                                                         rotamer_mask->pair_score[offset][k] = rotamer_mask->pair_score[offset][0] + k * rotamer_mask->count[j];  
00368                                                 }
00369                                         } else {
00370                                                 //rotamer_mask->pair_score[ target_hash( i, j) ] = NULL;                                        
00371                                                 int offset = i * target_data->len + j;
00372                                                 rotamer_mask->pair_score[ offset ] = NULL;                      
00373                                         }
00374                                 } else {
00375                                         int offset = i * target_data->len + j;
00376                                         rotamer_mask->pair_score[ offset ] = NULL;
00377                                 }
00378                         }
00379                 } else {
00380                         for (int j = 0; j < target_data->len; j++) {
00381                                 //rotamer_mask->pair_score[ target_hash( i, j) ] = NULL;
00382                                 int offset = i * target_data->len + j;
00383                                 rotamer_mask->pair_score[ offset ] = NULL;
00384                         }
00385                 }
00386         }               
00387         
00388         for (int i = 0; i < target_data->len; i++) {
00389                 if  ( target_data->structure[i].atom_count >= 4)  {
00390                         for (int j = i+1; j < target_data->len; j++) {
00391                                 if  ( target_data->structure[i].atom_count >= 4)  {
00392                                         float dist2 = pow( target_data->structure[i].atom[ anCb ][0] - target_data->structure[j].atom[ anCb ][0], 2) +
00393                                         pow( target_data->structure[i].atom[ anCb ][1] - target_data->structure[j].atom[ anCb ][1], 2) +
00394                                         pow( target_data->structure[i].atom[ anCb ][2] - target_data->structure[j].atom[ anCb ][2], 2) ;
00395                                         if ( dist2 < 240.25) {  //15.5^2  angstoms
00396                                                 for (int i_r = 0; i_r < rotamer_mask->count[i]; i_r++) {
00397                                                         //int offset = target_hash( i, j);
00398                                                         for (int j_r = 0; j_r < rotamer_mask->count[j]; j_r++) {
00399                                                                 //rotamer_mask->pair_score[offset][i_r][j_r] = 
00400                                                                 //calc_pair_energy(&rotamer_mask->structure[i][i_r], &rotamer_mask->structure[j][j_r]);
00401                                                                 rot_mask_pairscore_set( rotamer_mask, i, j, i_r, j_r, 
00402                                                                                                                 calc_pair_energy(&rotamer_mask->structure[i][i_r],
00403                                                                                                                                                  &rotamer_mask->structure[j][j_r]) 
00404                                                                                                                 );
00405                                                         }
00406                                                 }               
00407                                         } else {
00408                                                 for (int i_r = 0; i_r < rotamer_mask->count[i]; i_r++) {
00409                                                         for (int j_r = 0; j_r < rotamer_mask->count[j]; j_r++) {
00410                                                                 rot_mask_pairscore_set( rotamer_mask, i, j, i_r, j_r, 0.0 );
00411                                                         }
00412                                                 }
00413                                         }
00414                                 }               
00415                         }                       
00416                 }
00417         }       
00418         
00419         
00420         
00421         /*
00422          for (int i = 0; i < target_data->len; i++) {
00423                  for (int j =0; j < rotamer_mask->count[i]; j++) {
00424                          rotamer_mask->pair_score[i][j][
00425                                  */     
00426         
00427 
00428         return 0;
00429 }
00430 
00431 
00432 void Prospect_Target_RotamerMask_Free(RotamerMask *mask) {
00433         free( mask->count );
00434         free(mask->state[0] );
00435         free(mask->state );
00436         free(mask->rot_num[0] );
00437         free(mask->rot_num  );
00438         free(mask->structure[0] );
00439         free(mask->structure  );
00440         free(mask->single_score[0] );
00441         free(mask->single_score  );
00442         for (int i = 0; i < mask->len; i++) {
00443                 for (int j = 0; j < mask->len; j++) {
00444                         if ( mask->pair_score[ i * mask->len + j ] != NULL )
00445                                 free( mask->pair_score[ i * mask->len + j ] );
00446                 }
00447         }
00448 }
00449 
00450 pValType rot_mask_pairscore_get( RotamerMask *mask, long vertex_1, long vertex_2, long state_1, long state_2) {
00451         if (vertex_2 > vertex_1) {
00452                 return mask->pair_score[ vertex_1 * mask->len + vertex_2 ][ state_1 ][ state_2 ];               
00453         }
00454         return mask->pair_score[ vertex_2 * mask->len + vertex_1 ][ state_2 ][ state_1 ];       
00455 }
00456 
00457 
00458 void rot_mask_pairscore_set( RotamerMask *mask, long vertex_1, long vertex_2, long state_1, long state_2, pValType score) {
00459         if (vertex_2 > vertex_1) {
00460                 mask->pair_score[ vertex_1 * mask->len + vertex_2 ][ state_1 ][ state_2 ] = score;              
00461         } else {
00462                 mask->pair_score[ vertex_2 * mask->len + vertex_1 ][ state_2 ][ state_1 ] = score;      
00463         }
00464 }
00465 
00466 
00467 /*
00468  void Prospect_Target_RotamerMask_Compress(RotamerMask *mask) {
00469          for (int i = 0; i < mask->len; i++) {
00470                  int k = 0; 
00471                  for (int j = 0; j < mask->count[i]; j++) {
00472                          if ( mask->state[i][j] ) {
00473                                  mask->state[i][k] = mask->state[i][j];
00474                                  mask->rot_num[i][k] = mask->rot_num[i][j];
00475                                  mask->structure[i][k] = mask->structure[i][j];
00476                                  k++;
00477                          }
00478                  }
00479          }
00480  }
00481  */
00482 
00483 
00484 int Prospect_RotamerMask_DEE(RotamerMask *mask) {
00485         
00486                 
00487         char removal_found = 0; 
00488         do {
00489                 removal_found = 0;
00490                 for (int i = 0; i < mask->len; i++) {
00491                         if ( mask->count[i] != 0) {
00492                                 //fprintf(stderr, "Residue %d (%d rotamers)\n", i, mask->count[i]);
00493                                 for (int rot_a = 0; rot_a < mask->count[i]; rot_a++) {
00494                                         if ( mask->state[i][rot_a] ) {
00495                                                 float e_single_a = mask->single_score[i][rot_a];
00496                                                 for (int rot_b = 0; rot_b < mask->count[i]; rot_b++) {
00497                                                         if (rot_a != rot_b && mask->state[i][rot_b]) {
00498                                                                 float e_single_b = mask->single_score[i][rot_b];                                                        
00499                                                                 float e_pair_sum = 0;
00500                                                                 for ( int res_c = 0; res_c < mask->len; res_c++ ) {
00501                                                                         if ( res_c != i && mask->count[res_c] != 0 ) {
00502                                                                                 float min = 1000000000;
00503                                                                                 //int offset = i * mask->len + res_c;//target_hash( i, res_c );
00504                                                                                 for ( int rot_c = 0; rot_c < mask->count[res_c]; rot_c++) {
00505                                                                                         if ( mask->state[res_c][rot_c] ) {                                                                                                                                      
00506                                                                                                 //float e_pair_a = mask->pair_score[ offset ][ rot_a ][ rot_c ];
00507                                                                                                 //float e_pair_b = mask->pair_score[ offset ][ rot_b ][ rot_c ];
00508                                                                                                 float e_pair_a = rot_mask_pairscore_get( mask, i, res_c, rot_a, rot_c );
00509                                                                                                 float e_pair_b = rot_mask_pairscore_get( mask, i, res_c, rot_b, rot_c );                                                                                                                                                                                                                                                                         
00510                                                                                                 float tmp = e_pair_b - e_pair_a;
00511                                                                                                 if ( tmp < min ) {
00512                                                                                                         min = tmp;
00513                                                                                                 }
00514                                                                                         }
00515                                                                                 }
00516                                                                                 e_pair_sum += min;
00517                                                                         }
00518                                                                 }                                                       
00519                                                                 if ( e_single_b - e_single_a + e_pair_sum > 0 ) {
00520                                                                         //fprintf(stderr, "Eliminate res: %d rotamer: %d by %d (%f)\n", i, rot_b, rot_a, e_single_b - e_single_a + e_pair_sum);
00521                                                                         mask->state[i][rot_b] = 0;
00522                                                                         removal_found = 1;
00523                                                                 }
00524                                                         }
00525                                                 }
00526                                         }
00527                                 }
00528                         } else {
00529                                 //fprintf(stderr, "Residue %d skipped\n", i);
00530                         }
00531                 }
00532         } while (removal_found);        
00533         return 0;
00534 }
00535 
00536 
00537 /*
00538 int Prospect_RotamerMask_DEE2(RotamerMask *mask) {
00539         
00540         char removal_found = 0;
00541         
00542         do {
00543                 removal_found = 0;
00544                 for (int i = 0; i < mask->len; i++) {
00545                         int i_count = rot_mask_active(mask, i);                 
00546                         if ( i_count > 1 ) {
00547                                 for (int j = 0; j < mask->len; j++) {
00548                                         int j_count = rot_mask_active(mask, j);
00549                                         if (i != j && j_count > 1) {
00550                                                 for ( int ir_1 = 0; ir_1 < mask->count[i]; ir_1++) {
00551                                                         if ( mask->state[i][ ir_1 ] ) {
00552                                                                 for ( int jr_1 = 0; jr_1 < mask->count[j]; jr_1++) {
00553                                                                         if ( ir_1 != jr_1 && mask->state[j][ jr_1 ] ) {
00554                                                                                 long offset_1 = i * mask->len + j;
00555                                                                                 float val_1 = mask->single_score[i][ir_1] + mask->single_score[j][jr_1] + mask->pair_score[offset_1][ir_1][jr_1];
00556                                                                                 for ( int ir_2 = 0; ir_2 < mask->count[i]; ir_2++) {
00557                                                                                         if ( ir_1 != ir_2 && mask->state[i][ ir_2 ] ) {
00558                                                                                                 for ( int jr_2 = 0; jr_2 < mask->count[j]; jr_2++) {
00559                                                                                                         if ( jr_1 != jr_2 && mask->state[j][ jr_2 ] ) {
00560                                                                                                                 float val_2 = mask->single_score[i][ir_2] + mask->single_score[j][jr_2] + mask->pair_score[offset_1][ir_2][jr_2];
00561                                                                                                                 
00562                                                                                                                 for (int k = 0; k < mask->len; k++) {
00563                                                                                                                         if (k != i && k != j ) {
00564                                                                                                                                 
00565                                                                                                                         }
00566                                                                                                                 }       
00567                                                                                                         }
00568                                                                                                 }
00569                                                                                         }
00570                                                                                 }
00571                                                                         }
00572                                                                 }
00573                                                         }
00574                                                 }
00575                                         }
00576                                 }
00577                         }
00578                 }
00579         } while (removal_found);
00580         return 0;
00581 }
00582 */
00583 
00584 int Prospect_RotamerMask_DEE_Split(RotamerMask *mask) {
00585         
00586         int max_rotcount = 0;
00587         for (int i = 0; i < mask->len; i++) {
00588                 if ( max_rotcount < mask->count[i])
00589                         max_rotcount = mask->count[i];
00590         }
00591         
00592         float y_val[ mask->len][ max_rotcount ];                
00593         char y_val_set[ mask->len][ max_rotcount ];     
00594 
00595         char removal_found = 0; 
00596         do {
00597                 removal_found = 0;
00598                 for (int i = 0; i < mask->len; i++) {
00599                         for ( int ir_1 = 0; ir_1 < mask->count[i]; ir_1++) {
00600                                 if ( mask->state[i][ ir_1 ] ) {
00601                                         for ( int ir_2 = 0; ir_2 < mask->count[i]; ir_2++) {
00602                                                 if ( ir_1 != ir_2 && mask->state[i][ ir_2 ] ) {
00603                                                         for (int j = 0; j < mask->len; j++) {
00604                                                                 if ( i != j ) {
00605                                                                         y_val[j][ir_2] = 1e20;
00606                                                                         y_val_set[j][ir_2] = 0;
00607                                                                         for ( int jr_1 = 0; jr_1 < mask->count[j]; jr_1++) {
00608                                                                                 if ( mask->state[j][ jr_1 ] ) {
00609                                                                                         float val1 = rot_mask_pairscore_get( mask, i, j, ir_1, jr_1 );
00610                                                                                         float val2 = rot_mask_pairscore_get( mask, i, j, ir_2, jr_1 );
00611                                                                                         float test_val = val1 - val2;
00612                                                                                         if ( test_val < y_val[j][ir_2] ) {
00613                                                                                                 y_val[j][ir_2] = test_val;
00614                                                                                                 y_val_set[j][ir_2] = 1;
00615                                                                                         }
00616                                                                                 }
00617                                                                         }
00618                                                                 }
00619                                                         }                                                       
00620                                                 }
00621                                         }                               
00622                                         for (int k = 0; k < mask->len; k++) {
00623                                                 if ( k != i && rot_mask_active(mask, k) > 0 ) {
00624                                                         char v_mask[ max_rotcount ];
00625                                                         for (int tmp = 0; tmp < mask->count[k]; tmp++) 
00626                                                                 v_mask[tmp] = 0;
00627                                                         for ( int ir_2 = 0; ir_2 < mask->count[i]; ir_2++) {
00628                                                                 if ( ir_1 != ir_2 && mask->state[i][ ir_2 ] ) {
00629                                                                         float val_x = mask->single_score[i][ir_1] - mask->single_score[i][ir_2];
00630                                                                         for (int j = 0; j < mask->len; j++) {
00631                                                                                 if ( j != i && j != k) {
00632                                                                                         if ( y_val_set[j][ir_2] )
00633                                                                                                 val_x += y_val[j][ir_2];
00634                                                                                 }
00635                                                                         }
00636                                                                         for ( int kr_1 = 0; kr_1 < mask->count[k]; kr_1++) {
00637                                                                                 if ( mask->state[k][ kr_1 ] ) {
00638                                                                                         float val1 = rot_mask_pairscore_get( mask, i, k, ir_1, kr_1 );
00639                                                                                         float val2 = rot_mask_pairscore_get( mask, i, k, ir_2, kr_1 );
00640                                                                                         if ( val_x + (val1 - val2) > 0 ) {
00641                                                                                                 v_mask[ kr_1 ] = 1;
00642                                                                                         }
00643                                                                                 }
00644                                                                         }
00645                                                                 }
00646                                                         }
00647                                                         char eliminate = 1;
00648                                                         for (int tmp = 0; tmp < mask->count[k]; tmp++) {
00649                                                                 if ( mask->state[k][tmp] && !v_mask[ tmp ]) {
00650                                                                         eliminate = 0;
00651                                                                 }
00652                                                         }
00653                                                         if (eliminate) {
00654                                                                 mask->state[i][ir_1] = 0;
00655                                                                 k = mask->len;
00656                                                                 removal_found = 1;
00657                                                         }
00658                                                 }
00659                                         }
00660                                 }                       
00661                         }
00662                 }
00663         } while (removal_found);        
00664         return 0;
00665 }
00666 
00667 
/*
 * Depth-first dump of a tree stored as an adjacency matrix, starting at
 * cur_node. Every edge is written to stderr as "parent -> child".
 *
 * `source` is the node the walk arrived from. Pass -1 at the root. That node
 * is skipped so the walk never steps back along its incoming edge. Self-loops
 * (matrix[cur_node][cur_node]) are ignored as well.
 *
 * NOTE: only the parent is excluded, so a matrix that contains a cycle would
 * recurse without bound. Callers are expected to pass an acyclic structure.
 */
void sc_tree_print(char **matrix, long matrix_size, long cur_node, long source) {
	for (int child = 0; child < matrix_size; ++child) {
		// Cheap index checks first; the matrix is only read for a real candidate.
		if (child == cur_node || child == source || !matrix[cur_node][child])
			continue;
		fprintf(stderr, "%ld -> %d\n", cur_node, child);
		sc_tree_print(matrix, matrix_size, child, cur_node);
	}
}
00678 
00679 
00680 int Prospect_RotamerMask_TreeDecomp( RotamerMask *mask ) {      
00681         
00682         //figure out the tree decomposition of the problem      
00683         //first, create the graph of connected points
00684         long vertex_count = 0;
00685         for (int i = 0; i < mask->len; i++) {
00686                 if ( mask->count[i] != 0) {
00687                         int count = 0;
00688                         for (int rot_a = 0; rot_a < mask->count[i]; rot_a++) {
00689                                 if ( mask->state[i][rot_a] ) {
00690                                         count++;
00691                                 }
00692                         }
00693                         char connect = 0;
00694                         for (int j = 0; j < mask->len; j++) {
00695                                 if ( j != i ) {
00696                                         //int offset = i * mask->len + j;//target_hash( i, j );
00697                                         for (int ir =0; ir < mask->count[i]; ir++) {
00698                                                 if ( mask->state[i][ir])  {
00699                                                         for (int jr =0; jr < mask->count[j]; jr++) {
00700                                                                 if ( mask->state[j][jr])  {
00701                                                                         //if (mask->pair_score[ offset ][ ir ][ jr ] > 0 ) 
00702                                                                         if ( rot_mask_pairscore_get( mask, i, j, ir, jr ) > 0 ) {
00703                                                                                 connect = 1;
00704                                                                         }
00705                                                                 }
00706                                                         }
00707                                                 }
00708                                         }
00709                                 }
00710                         }
00711                         if (connect) {
00712                                 vertex_count++;                         
00713                         }
00714                 }
00715         }
00716         
00717         
00718         //do the vertex hashes for the residues, and count the number of states
00719         long *vertex_hash = (long *)malloc(sizeof(long) * vertex_count );
00720         long *state_count = (long *)malloc(sizeof(long) * vertex_count );
00721         long total_state_count = 0;
00722         vertex_count = 0;       
00723         for (int i = 0; i < mask->len; i++) {
00724                 if ( mask->count[i] != 0) {
00725                         int count = 0;
00726                         for (int rot_a = 0; rot_a < mask->count[i]; rot_a++) {
00727                                 if ( mask->state[i][rot_a] ) {
00728                                         count++;
00729                                 }
00730                         }
00731                         char connect = 0;
00732                         for (int j = 0; j < mask->len; j++) {
00733                                 if ( j != i ) {
00734 //                                      int offset = i * mask->len + j;//target_hash( i, j );
00735                                         for (int ir =0; ir < mask->count[i]; ir++) {
00736                                                 if ( mask->state[i][ir])  {
00737                                                         for (int jr =0; jr < mask->count[j]; jr++) {
00738                                                                 if ( mask->state[j][jr])  {
00739                                                                         //if (mask->pair_score[ offset ][ ir ][ jr ] > 0 ) 
00740                                                                         if (rot_mask_pairscore_get( mask, i, j, ir, jr) > 0 ) {
00741                                                                                 connect = 1;
00742                                                                         }
00743                                                                 }
00744                                                         }
00745                                                 }
00746                                         }
00747                                 }
00748                         }
00749                         if (connect) {
00750                                 vertex_hash[ vertex_count ] = i;
00751                                 state_count[ vertex_count ] = count;
00752                                 total_state_count += count;
00753                                 vertex_count++;
00754                         }
00755                 }
00756         }
00757         
00758         //we also need to create the rotamer hash
00759         long **rot_hash = (long **)malloc(sizeof(long *) * vertex_count);
00760         rot_hash[0] = (long *)malloc(sizeof(long) * total_state_count);
00761         total_state_count = 0;
00762         for (int i = 0; i < vertex_count; i++) {
00763                 rot_hash[i] = rot_hash[0] + total_state_count;
00764                 int k = 0;
00765                 for (int j = 0; j <  mask->count[ vertex_hash[i] ]; j++) {
00766                         if ( mask->state[ vertex_hash[i] ][j] ) {
00767                                 rot_hash[i][k] = j;
00768                                 k++;
00769                         }
00770                 }
00771                 total_state_count += state_count[ i ];
00772         }
00773         
00774         
00775         //draw the graph connection matrix
00776         char **graph_matrix = td_matrix_init( vertex_count );
00777         for (int i = 0; i < vertex_count; i++) {
00778                 for (int j = i+1; j < vertex_count; j++) {
00779                         int vi = vertex_hash[i];
00780                         int vj = vertex_hash[j];                        
00781                         //int offset = vi * mask->len + vj;   //target_hash( vi, vj )
00782                         char connect = 0;
00783                         for (int ir = 0; ir < state_count[i]; ir++) {
00784                                 int vir = rot_hash[i][ir];
00785                                 for (int jr =0; jr < state_count[j]; jr++) {
00786                                         int vjr = rot_hash[j][jr];
00787                                         //if (mask->pair_score[ offset ][ vir ][ vjr ] > 0 ) 
00788                                         if ( rot_mask_pairscore_get( mask, vi, vj, vir, vjr ) > 0 ) 
00789                                                 connect = 1;
00790                                 }
00791                         }
00792                         if (connect) {
00793                                 graph_matrix[i][j] = 1;
00794                                 graph_matrix[j][i] = 1;
00795                         }
00796                 }
00797         }       
00798         
00799         td_graph *sc_graph = td_graph_init( vertex_count, state_count, graph_matrix );  
00800         
00801         //fill in the score values of the graph
00802         for (int i = 0; i < vertex_count; i++) {
00803                 int vi = vertex_hash[i];
00804                 for (int ir = 0; ir < state_count[i]; ir++) {
00805                         int vir = rot_hash[i][ir];
00806                         sc_graph->vertex_score[i][ir] = mask->single_score[ vi ][vir];
00807                 }               
00808                 for (int j = i+1; j < vertex_count; j++) {
00809                         if ( sc_graph->matrix[i][j]) {
00810                                 int vj = vertex_hash[j];
00811                                 long graph_offset = td_edge_offset( sc_graph, i, j );
00812                                 //long mask_offset = vi * mask->len + vj;  //target_hash( vi, vj );
00813                                 for (int ir = 0; ir < state_count[i]; ir++) {
00814                                         int vir = rot_hash[i][ir];
00815                                         for (int jr =0; jr < state_count[j]; jr++) {
00816                                                 int vjr = rot_hash[j][jr];
00817                                                 //sc_graph->edge_score[ graph_offset ][ ir ][ jr ] = mask->pair_score[ mask_offset ][ vir ][ vjr ];
00818                                                 sc_graph->edge_score[ graph_offset ][ ir ][ jr ] = rot_mask_pairscore_get( mask, vi, vj, vir, vjr ) ;
00819                                         }
00820                                 }
00821                         }
00822                 }
00823         }
00824         
00825         
00826         
00827         td_decomp* decomp = td_graph_decomp_tri( sc_graph );
00828         
00829 #ifdef DEBUG
00830         for (int i = 0; i < decomp->bag_count; i++) {
00831                 fprintf(stderr, "bag %d: ", i);
00832                 for (int j = 0; j < decomp->bag_size[i]; j++) {
00833                         fprintf(stderr, "%d ", decomp->bag[i][j] );
00834                 }
00835                 fprintf(stderr, "\n");
00836         }
00837         fprintf(stderr, "Tree root: %d\n", decomp->root);
00838         sc_tree_print(decomp->bag_matrix, decomp->bag_count, decomp->root, -1);
00839 #endif
00840         
00841 #if 1
00842         /*
00843          for (int i = 0; i < decomp->bag_count; i++) {
00844                  printf("%d: ", i);
00845                  for ( int j = 0; j < decomp->bag_size[i]; j++) {
00846                          printf(" %d", vertex_hash[ decomp->bag[i][j] ]);
00847                  }
00848                  printf("\n");
00849          }
00850          */
00851         long *opt_state = td_graph_decomp_solve( decomp );
00852         
00853         for (int i = 0; i < vertex_count; i++) {
00854                 int vi = vertex_hash[i];
00855                 int vir = rot_hash[i][ opt_state[i] ];
00856                 for (int j = 0; j < mask->count[vi]; j++) {
00857                         if ( j == vir ) {
00858                                 mask->state[vi][j] = 1;
00859                         } else {
00860                                 mask->state[vi][j] = 0;
00861                         }                               
00862                 }
00863                 
00864         }
00865         
00866 #else
00867         
00868         //output the graph
00869         printf("graph G {\n");
00870         /*
00871          for (int i = 0; i < mask->len; i++) {
00872                  if ( mask->count[i] != 0) {
00873                          int count = 0;
00874                          int a = 0;
00875                          for (int rot_a = 0; rot_a < mask->count[i]; rot_a++) {
00876                                  if ( mask->state[i][rot_a] ) {
00877                                          if (count == 0)
00878                                                  a = rot_a;
00879                                          count++;
00880                                  }
00881                          }
00882                          //     if ( count > 1 ) {
00883                          for (int j = i+1; j < mask->len; j++) {
00884                                  if ( j != i ) {
00885                                          int offset = i * mask->len + j;//target_hash( i, j );
00886                                          char connect = 0;
00887                                          for (int ir =0; ir < mask->count[i]; ir++) {
00888                                                  if ( mask->state[i][ir])  {
00889                                                          for (int jr =0; jr < mask->count[j]; jr++) {
00890                                                                  if ( mask->state[j][jr])  {
00891                                                                          if (mask->pair_score[ offset ][ ir ][ jr ] > 0 ) {
00892                                                                                  connect = 1;
00893                                                                          }
00894                                                                  }
00895                                                          }
00896                                                  }
00897                                          }
00898                                          if (connect) {
00899                                                  printf("\t%d -- %d;\n", i, j);
00900                                          }
00901                                  }
00902                                  //     }
00903                          }
00904                  }
00905          }              
00906          for (int i = 0; i < mask->len; i++) {
00907                  if ( mask->count[i] != 0) {
00908                          int count = 0;
00909                          int a = 0;
00910                          for (int rot_a = 0; rot_a < mask->count[i]; rot_a++) {
00911                                  if ( mask->state[i][rot_a] ) {
00912                                          if (count == 0)
00913                                                  a = rot_a;
00914                                          count++;
00915                                  }
00916                          }
00917                          if (count == 1 || count == 0) {
00918                                  printf("\t%d [shape=triangle];\n", i);
00919                          }
00920                  }
00921          }
00922          */
00923         
00924         for (int i = 0; i < vertex_count; i++) {
00925                 for (int j = i+1; j < vertex_count; j++) {
00926                         if (graph_matrix[i][j]) {
00927                                 printf("\t%d -- %d\n", i, j);
00928                         }
00929                 }               
00930         }
00931         
00932         /*      
00933                 for (int i = 0; i < decomp->bag_count; i++) {
00934                         printf("\tsubgraph subgraph%d {\n", i);
00935                         for (int j = 0; j < decomp->bag_size[i]; j++) {
00936                                 for ( int k = j+1; k < decomp->bag_size[i]; k++) {
00937                                         if ( graph_matrix[ decomp->bag[i][j] ][ decomp->bag[i][k] ] ) {
00938                                                 printf( "\t\tsub%d_%d -- sub%d_%d;\n", i, decomp->bag[i][j], i, decomp->bag[i][k] );
00939                                         }
00940                                 }
00941                         }
00942                         printf("\t}\n");
00943                 }       
00944          */
00945         for (int i = 0; i < decomp->bag_count; i++) {
00946                 for (int j = 0; j < decomp->bag_count; j++) {
00947                         if ( decomp->bag_matrix[i][j] ) {
00948                                 printf("\tbag%d -- bag%d;\n", i, j);
00949                         }
00950                 }
00951         }
00952         for (int i = 0; i < decomp->bag_count; i++) {
00953                 char buffer[2000];
00954                 strcpy(buffer, "");
00955                 for (int j = 0; j < decomp->bag_size[i]; j++) {
00956                         char buffer2[100];
00957                         sprintf(buffer2, "%d, ", decomp->bag[i][j]);
00958                         strcat(buffer, buffer2);
00959                 }
00960                 printf("\tbag%d [label=\"%s\"];\n", i, buffer);
00961         }
00962         
00963         printf("}\n");  
00964         
00965         
00966 #endif
00967         return 0;
00968 }
00969 
00970 
00971 

Generated on Wed Apr 11 16:50:50 2007 for open_prospect by  doxygen 1.4.6