Main Page | Modules | Namespace List | Class Hierarchy | Alphabetical List | Class List | Directories | File List | Namespace Members | Class Members | File Members | Related Pages

MdaDiscrimAna.cxx

Go to the documentation of this file.
00001 
00002 
00003 
00004 
00005 
00006 
00007 
00008 
00009 
00010 
00011 
00012 
00013 
00014 #include <iostream>
00015 using std::cout;
00016 using std::endl;
00017 
00018 #include <fstream>
00019 using std::ifstream;
00020 
00021 #include <string>
00022 using std::string;
00023 
00024 #include "MessageService/MsgService.h"
00025 #include "NueAna/MdaDiscrimAna.h"
00026 #include "NueAna/StringUtil.h"
00027 #include "NueAna/NueRecord.h"
00028 #include "NueAna/NueAnaTools/NueConvention.h"
00029 #include "AnalysisNtuples/ANtpDefaultValue.h"
00030 
00031 #include "TList.h"
00032 #include "TClass.h"
00033 #include "TDataType.h"
00034 #include "TDataMember.h"
00035 #include "TRealData.h"
00036 
00037 ClassImp(MdaDiscrimAna)
00038 
00039 CVSID("$Id: MdaDiscrimAna.cxx,v 1.9 2007/03/01 16:38:49 rhatcher Exp $");
00040 
00041 Int_t MdaDiscrimAna::nClass;
00042 std::deque <TMatrixD> MdaDiscrimAna::quadCoeff;
00043 std::deque <TVectorD> MdaDiscrimAna::linCoeff;
00044 std::deque <TVectorD> MdaDiscrimAna::constCoeff;
00045 std::deque <TVectorD> MdaDiscrimAna::meanVec;   
00046 
00047 std::deque < std::deque <std::string> > MdaDiscrimAna::typeClassVec;
00048 std::deque < std::string> MdaDiscrimAna::typeCoeffVec;
00049 std::deque < std::string> MdaDiscrimAna::varNameVec;
00050     
00051 MdaDiscrimAna::MdaDiscrimAna(NueRecord &nr, MdaDiscrim &md):
00052     nueRec(nr),
00053     fMdaDiscrim(md)
00054 {
00055 
00056 }
00057 
00058 MdaDiscrimAna::~MdaDiscrimAna()
00059 {
00060 
00061 }
00062 
00063 
00064 void MdaDiscrimAna::Analyze()
00065 {
00066 
00067     MSG("MdaDiscrimAna",Msg::kDebug)<<"In MdaDiscrimAna::Analyze"<<endl;
00068 
00069     //Fill SAS Calibration Arrays
00070     if (!isFillDone()) isFilled = FillCalibArrays();
00071     
00072     //Fill P-Vector values from NueRecord
00073     FillPVector();
00074     
00075     //Calculate nue PID from discriminator function
00076     PIDCalc();
00077     
00078     //Classify events according to a probability threshold cut  
00079     MdaClassify();
00080             
00081 }
00082 
00083 Bool_t MdaDiscrimAna::FillCalibArrays()
00084 {
00085     
00086     //Reads the SAS coefficient input file 
00087     //(containing quadratic, linear and constant coeffs)
00088     //and fills the necessary matrices and vectors
00089     //Also fills a vector 
00090 
00091     static const Int_t kLineSize=2000;
00092 
00093     ifstream sasInput;
00094     sasInput.open(sasFileName.c_str());
00095     if (sasInput.fail()) {
00096         
00097         MSG("MdaDiscrimAna",Msg::kInfo)<< "SAS coefficient file not found! " 
00098                                         << "Not calculating MDA PID..." 
00099                                         << endl;
00100         return false;
00101     }
00102 
00103     Char_t oneLine[kLineSize];
00104 
00105     string sepaCol=" ";
00106     vector<string> oneVec;
00107          
00108     //Read the first line (mean values) and get the number 
00109     //of columns in the file
00110     sasInput.getline(oneLine,kLineSize);
00111     string topLine=oneLine;
00112     StringUtil::SplitString(topLine,sepaCol,oneVec);
00113     
00114     //Read the mean values into a vector
00115     TVectorD meanTemp(oneVec.size()-2);
00116     for(UInt_t col=2; col < oneVec.size(); col++){
00117         meanTemp(col-2)=(atof((oneVec.at(col)).c_str()));
00118     }
00119     
00120     for(Int_t m=0; m < meanTemp.GetNoElements(); m++){
00121         MSG("MdaDiscrimAna",Msg::kDebug)<< "Mean:" 
00122                                         << m 
00123                                         << "="
00124                                         << meanTemp(m) 
00125                                         << endl; 
00126     }
00127 
00128     meanVec.push_back(meanTemp);
00129 
00130     //Declare matrices and vectors for coefficient reading
00131     TMatrixD quadTemp(oneVec.size()-2, oneVec.size()-2);
00132     
00133     TVectorD linTemp(oneVec.size()-2);
00134     
00135     TVectorD constTemp(oneVec.size()-2);
00136 
00137     deque <string> tempClassVec;
00138 
00139     //Count the rows for each matrix
00140     Int_t rowCounter=0;
00141     
00142     //Reset the number of classes counter
00143     nClass=0;
00144 
00145     //Reset the ifstream pointer to the beginning of the file
00146     // sasInput.seekg(0,ios::beg);
00147     
00148     //Reset the line vector
00149     oneVec.clear();
00150     
00151     while (! sasInput.eof() ){
00152         sasInput.getline(oneLine,kLineSize);
00153         topLine=oneLine;
00154         StringUtil::SplitString(topLine,sepaCol,oneVec);
00155         
00156         string coeffType;            
00157         
00158         for(UInt_t col=0; col < oneVec.size(); col++){
00159             
00160             if (col==0) {
00161             
00162                 tempClassVec.push_back(oneVec.at(col));
00163                 
00164             }else if (col==1){ 
00165                 
00166                 coeffType=oneVec.at(col);
00167                 
00168                 if(coeffType!="_LINEAR_" && coeffType !="_CONST_"){
00169                     if (varNameVec.size()<(oneVec.size()-2)) 
00170                         varNameVec.push_back(oneVec.at(col));
00171                     typeCoeffVec.push_back("_QUAD_");
00172                 }else typeCoeffVec.push_back(oneVec.at(col));
00173              
00174             }else{
00175              
00176                 if (coeffType=="_LINEAR_"){
00177                     
00178                     linTemp(col-2)=(atof((oneVec.at(col)).c_str()));
00179                  
00180                 }else if (coeffType=="_CONST_"){
00181                 
00182                     constTemp(col-2)=(atof((oneVec.at(col)).c_str()));
00183                     
00184                 }else{
00185                     quadTemp(rowCounter,(col-2))
00186                         =(atof((oneVec.at(col)).c_str()));
00187                 }
00188             }
00189         }
00190         oneVec.clear();
00191         rowCounter++;
00192         
00193         if (coeffType=="_CONST_"){
00194             nClass++;
00195             rowCounter=0;
00196             typeClassVec.push_back(tempClassVec);
00197             quadCoeff.push_back(quadTemp);
00198             linCoeff.push_back(linTemp);
00199             constCoeff.push_back(constTemp);            
00200         
00201             tempClassVec.clear();
00202 
00203         }else continue;
00204     }
00205     
00206     if(sasInput) sasInput.close();
00207 
00208     return true;
00209 }
00210 
00211 void MdaDiscrimAna::FillPVector()
00212 {
00213     //Now use varNameVec to fill the variable values p-vector
00214     //Shameless rip-off of CompareAll::FillFromList
00215 
00216     NueRecord *nueR=&nueRec;
00217     
00218     if(varNameVec.size() == 0) return;
00219     TString hname;
00220     UInt_t count = 0;
00221     
00222     TClass *cl;
00223     TRealData *rd;
00224     string vName;
00225     TDataMember *member;
00226     TDataType *membertype;
00227     Double_t value = 0.0;
00228     Int_t   valueI = 0;
00229     
00230     cl=nueR->IsA();
00231     TIter  next(cl->GetListOfRealData());                                                                                 
00232     TVectorD valueTemp(varNameVec.size());
00233 
00234     while ((rd =dynamic_cast<TRealData*>(next()))) {
00235         member = rd->GetDataMember();
00236         membertype = member->GetDataType();
00237         vName=rd->GetName();
00238         
00239         TString vNameOrig=vName;
00240 
00241         //Replace "." with "_" in the ntuple variable names
00242         if (vName.length() > vName.find_first_of("."))
00243             vName.replace(vName.find_first_of("."),1,"_");
00244         
00245         Int_t offset = rd->GetThisOffset();
00246         Char_t *pointer = reinterpret_cast<Char_t*>(nueR)  + offset;
00247         
00248         for(UInt_t i = 0; i < varNameVec.size();i++){
00249             MSG("MdaDiscrimAna",Msg::kDebug)<<"Found variable "
00250                                             << "NueRec: " << vName 
00251                                             << " SAS: " << varNameVec.at(i) 
00252                                             << endl;
00253             if(vName == varNameVec.at(i)){
00254               value = ANtpDefVal::kDouble;
00255               valueI = ANtpDefVal::kInt;
00256               if(!NeedsSpecialAttention(vNameOrig, nueR, value))
00257 
00258               if (!strcmp(membertype->GetTypeName(),"float") ||
00259                   !strcmp(membertype->GetTypeName(),"Float_t") || 
00260                   !strcmp(membertype->GetTypeName(),"double")  ||
00261                   !strcmp(membertype->GetTypeName(),"Double_t")){
00262                 value=atof(membertype->AsString(pointer));
00263                 valueI=1;
00264               }
00265               else if(!strcmp(membertype->GetTypeName(),"int") ||
00266                       !strcmp(membertype->GetTypeName(),"Int_t")){
00267                 value=atoi(membertype->AsString(pointer));
00268                 valueI=atoi(membertype->AsString(pointer));
00269               }
00270               
00271               else MSG("MdaDiscrimAna",Msg::kWarning)<<"Found variable "
00272                                                    << "NueRec: " << vName
00273                                                    << " of unknown type "
00274                                                    << membertype->GetTypeName()
00275                                                    << endl ;
00276                 
00277               MSG("MdaDiscrimAna",Msg::kDebug)<<"Found variable "
00278                                              <<vName
00279                                              <<" with value "
00280                                              <<value
00281                                              <<endl;
00282                 
00283                 if(!ANtpDefVal::IsDefault(value) && 
00284                    !ANtpDefVal::IsDefault(valueI)){
00285                     
00286                     valueTemp(i)=value;
00287 
00288                 }else {
00289                     valueTemp(i)=(meanVec.at(0))(i);
00290                 }
00291                 MSG("MdaDiscrimAna",Msg::kDebug)<<"Found variable "
00292                                                 <<vName<<" with value "
00293                                                 <<value
00294                                                 <<endl;                
00295                 count++;
00296                 i = varNameVec.size();
00297             }
00298         }
00299         if(count == varNameVec.size()) break;
00300     }
00301     
00302     varPVector.push_back(valueTemp);
00303 
00304     return;     
00305     
00306 }
00307 
00308 bool MdaDiscrimAna::isFilled;
00309 
00310 void MdaDiscrimAna::MdaClassify()
00311 {
00312 
00313     //Use the calculated class probabilities and user supplied threshold cut
00314     //to perform posterior event classification
00315     
00316     //Determine the maximum probability and its index
00317 
00318     if(probClass.size() < static_cast<UInt_t>(nClass)) return;
00319 
00320     IterDeqDouble_t maxIter =
00321         max_element(probClass.begin(),probClass.end());
00322 
00323     Int_t maxPos= distance(probClass.begin(),maxIter);
00324     MSG("MdaDiscrimAna",Msg::kDebug)<< "Prob Max found for entry " 
00325                                     << maxPos << ": " 
00326                                     << probClass.at(maxPos) 
00327                                     << " with class " 
00328                                     << (typeClassVec.at(maxPos)).at(0)
00329                                     << endl;
00330     
00331     //Classify into class ()
00332     Double_t probMax=probClass.at(maxPos);
00333     string classMax=(typeClassVec.at(maxPos)).at(0);
00334     
00335     if (probMax < threshCut) fMdaDiscrim.fMdaClass=-1;
00336     else if (classMax=="nue") fMdaDiscrim.fMdaClass=ClassType::nue;
00337     else if (classMax=="ncu") fMdaDiscrim.fMdaClass=ClassType::NC;
00338     else if (classMax=="num") fMdaDiscrim.fMdaClass=ClassType::numu;
00339     else if (classMax=="nut") fMdaDiscrim.fMdaClass=ClassType::nutau;
00340     else {
00341         MSG("MdaDiscrimAna",Msg::kError)
00342              <<"Unknown class in SAS calibration set" << endl;
00343         fMdaDiscrim.fMdaClass=ANtpDefVal::kInt;
00344     }
00345     return;
00346 
00347 }
00348 
00349 
00350 Bool_t MdaDiscrimAna::NeedsSpecialAttention(TString name
00351                                             , NueRecord *nr
00352                                             , Double_t &value)
00353 {
00354    
00355     //All the fHeaders and four of the MST vars require special effort
00356     if(name == "fHeader.fSnarl") {
00357         value = nr->GetHeader().GetSnarl();
00358     }if(name == "fHeader.fRun") {
00359         value = nr->GetHeader().GetRun();
00360     }if(name == "fHeader.fSubRun") {
00361         value = nr->GetHeader().GetSubRun();
00362     }if(name == "fHeader.fEvtNo") {
00363          value = nr->GetHeader().GetEventNo();
00364     }if(name == "fHeader.fEvents") {
00365         value = nr->GetHeader().GetEvents();
00366     }if(name == "fHeader.fTrackLength") {
00367         value = nr->GetHeader().GetTrackLength();
00368     }
00369 
00370     if(name == "mstvars.eallw1") {
00371         if(nr->mstvars.enn1 > 0) value = 0.0;
00372         for(int i=0;i<nr->mstvars.enn1;i++){
00373             value += nr->mstvars.eallw1[i];
00374         }
00375     }
00376     if(name == "mstvars.oallw1") {
00377         if(nr->mstvars.onn1 > 0) value = 0.0;
00378         for(int i=0;i<nr->mstvars.onn1;i++){
00379             value += nr->mstvars.oallw1[i];
00380         }
00381     } 
00382     if(name == "mstvars.eallm1") {
00383         if(nr->mstvars.enn1 > 0) value = 0.0;
00384         for(int i=0;i<nr->mstvars.enn1;i++){
00385             value += nr->mstvars.eallm1[i];
00386         }
00387     }
00388     if(name == "mstvars.oallm1") {
00389         if(nr->mstvars.onn1 > 0) value = 0.0;
00390         for(int i=0;i<nr->mstvars.onn1;i++){
00391             value += nr->mstvars.oallm1[i];
00392         }
00393     }    
00394     
00395     if(value > -9999) return true;
00396     return false;
00397 }
00398 
00399 
00400 void MdaDiscrimAna::PIDCalc()
00401 {
00402 
00403     DeqDouble_t discrimValue;
00404     
00405     //Calculate discriminant function Q for each class
00406     for (Int_t n=0; n < nClass; n++){
00407         
00408         //Create an auxiliary copy of the Pvector
00409         TVectorD varPClone=(varPVector.at(0));
00410 
00411         //Note that varPClone is modified by the "*=" operator
00412         Double_t quadTerm=((varPClone)*=(quadCoeff.at(n)))*(varPVector.at(0));
00413 
00414         Double_t linTerm=(linCoeff.at(n))*(varPVector.at(0));
00415 
00416         Double_t constTerm=(constCoeff.at(n))(0);
00417         
00418         discrimValue.push_back(quadTerm+linTerm+constTerm);
00419         
00420         MSG("MdaDiscrimAna",Msg::kDebug)<< "Q("
00421                                         << (typeClassVec.at(n)).at(0) 
00422                                         <<") = " 
00423                                         << discrimValue.at(n) <<endl;
00424     }
00425      
00426     TVectorD tempDiff(nClass-1);
00427 
00428     //Calculate discriminator difference vectors
00429     //P(a) = e^a/(e^(a)+e^(b)+e^(c)+...) = 1/(1+e^(b-a)+e^(c-a)+...)
00430     //Avoids e^X where X is big => machine-dep floating exception.
00431   
00432     for (Int_t n=0; n < nClass; n++){
00433 
00434         Int_t diffCount=0;
00435         
00436         for (Int_t m=0; m < nClass; m++){
00437             if(m!=n){
00438                 tempDiff(diffCount)=(discrimValue.at(m)-discrimValue.at(n));
00439                 diffCount++;
00440             } else continue;   
00441         }
00442 
00443         if (TMath::Abs(tempDiff.Min())>600. || TMath::Abs(tempDiff.Max())>600){
00444             MSG("MdaDiscrimAna",Msg::kDebug)<<"Bypass floating exception "
00445                                             << endl;
00446             continue;
00447         }else discrimDiff.push_back(tempDiff); 
00448 
00449         tempDiff.Zero();
00450     }
00451   
00452     if(discrimDiff.size()!=static_cast<UInt_t>(nClass)) return;
00453 
00454     //Calculate probability value for each class
00455     for (Int_t n=0; n < nClass; n++){
00456         
00457         Double_t denomSum=0.0;
00458         
00459         for (Int_t m=0; m < (nClass-1); m++){            
00460             denomSum+=TMath::Exp((discrimDiff.at(n))(m));
00461         }
00462 
00463         probClass.push_back(1.0/(1.0+denomSum));
00464 
00465         MSG("MdaDiscrimAna",Msg::kDebug) << "P("
00466                                          << (typeClassVec.at(n)).at(0) 
00467                                          <<") = " 
00468                                          << probClass.at(n) <<endl;
00469         
00470         if((typeClassVec.at(n)).at(0)=="nue") 
00471             fMdaDiscrim.fMdaPIDnue=probClass.at(n);        
00472         if((typeClassVec.at(n)).at(0)=="ncu") 
00473             fMdaDiscrim.fMdaPIDnc=probClass.at(n);        
00474         if((typeClassVec.at(n)).at(0)=="num") 
00475             fMdaDiscrim.fMdaPIDnumu=probClass.at(n);        
00476         if((typeClassVec.at(n)).at(0)=="nut") 
00477             fMdaDiscrim.fMdaPIDnutau=probClass.at(n);        
00478     }
00479     
00480     return;
00481 }
00482 
00483 
00484 void MdaDiscrimAna::Reset()
00485 {
00486  
00487 
00488 }

Generated on Mon Feb 15 11:06:58 2010 for loon by  doxygen 1.3.9.1