00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00023
00024
00025
00026 #include "NueAna/ParticlePID/ParticleAna/NNTrain.h"
00027 #include <TTree.h>
00028 #include <TMath.h>
00029
00030 #include "TMLPAnalyzer.h"
00031 #include <TLegend.h>
00032
00033 #include "TH1F.h"
00034 #include "TH2F.h"
00035 #include "TFile.h"
00036 #include "TChain.h"
00037 #include "TCanvas.h"
00038 #include "TBranch.h"
00039 #include "TLeaf.h"
00040 #include "TRandom.h"
00041
00042 #include "NueAna/ParticlePID/ParticleAna/PRecord.h"
00043 #include "NueAna/NueStandard.h"
00044
00045
00046
00047
00048
00049
00050
00051 NNTrain::NNTrain()
00052 {
00053 for(int i=0;i<3;i++)
00054 trainfile[i].clear();
00055 ResetTrainParams();
00056 mlp=0;
00057 needToWeight=0;
00058
00059
00060 fBaseLine = 735;
00061 fDeltaMS12= 8.0;
00062 fTh12= 0.816;
00063 fTh23=3.141592/4.0;
00064 fDensity=2.75;
00065 SetOscParamBase(0.00243, TMath::ASin(TMath::Sqrt(0.15))/2.0, 0, 1);
00066
00067 nr=new NueRecord();
00068
00069
00070 traincontinue=0;
00071
00072 f2=new TF1("f2","sin(1.267*[0]*[1]/x)*sin(1.267*[0]*[1]/x)",0.,120.);
00073 float delm2=0.0024;
00074 float Ltd=735.;
00075 f2->SetParameters(delm2,Ltd);
00076
00077 }
00078
00079
00080 NNTrain::~NNTrain()
00081 {}
00082
00083
00084 void NNTrain::ResetTrainParams()
00085 {
00086 delta=0;
00087 epsilon=0;
00088 eta=0;
00089 etadecay=0;
00090 tau=0;
00091 }
00092
00093
00094
00095 void NNTrain::MakeTestTree()
00096 {
00097
00098 MakeTrainTree(1);
00099
00100 }
00101
00102
00103
00104 void NNTrain::FillTreePid(string file,string outfile, string MLPfile)
00105 {
00106
00107 TFile * mlpfile = TFile::Open(MLPfile.c_str());
00108 TMultiLayerPerceptron * mlp = 0;
00109 if(mlpfile)
00110 {
00111 mlp=(TMultiLayerPerceptron*)mlpfile->Get("MLP")->Clone();
00112 mlpfile->Close();
00113 }
00114
00115
00116 if(!mlp)
00117 {
00118 printf("unable to load MLP\n");
00119 exit(1);
00120 }
00121
00122
00123 TFile * fin = TFile::Open(file.c_str(), "UPDATE");
00124 if(!fin)
00125 {
00126 printf("no input file specified\n");
00127 exit(1);
00128 }
00129
00130 TTree * tree = (TTree*)fin->Get("TrainTree");
00131 if(!tree)
00132 {
00133 printf("no traintree tree found in input file!\n");
00134 exit(1);
00135 }
00136
00137 TTree * descriptor = (TTree*)fin->Get("descriptor");
00138 if(!descriptor)
00139 {
00140 printf("no descriptor tree found in input file!\n");
00141 exit(1);
00142 }
00143
00144
00145 double pid;
00146 double pars[100];
00147
00148 TFile * fout = TFile::Open(outfile.c_str(),"RECREATE");
00149 fout->cd();
00150 TTree * treeout = (TTree*)tree->CloneTree(0);
00151 TTree * descriptorout = (TTree*)descriptor->CloneTree();
00152
00153 if(treeout->GetBranch("pid"))
00154 treeout->SetBranchAddress("pid",&pid);
00155 else
00156 treeout->Branch("pid",&pid);
00157
00158
00159 tree->SetBranchAddress("pars",&pars);
00160 treeout->SetBranchAddress("pars",&pars);
00161
00162 int ent = tree->GetEntries();
00163 printf("Filling pid for %d entries\n",ent);
00164 for(int i=0;i<ent;i++)
00165 {
00166 tree->GetEntry(i);
00167 pid=mlp->Evaluate(0,pars);
00168 treeout->Fill();
00169 }
00170 treeout->Write("TrainTree",TObject::kOverwrite);
00171 descriptorout->Write();
00172 fin->Close();
00173 fout->Close();
00174
00175 }
00176
00177 TTree * NNTrain::MakeTrainTree(int makeTestTree)
00178 {
00179 int files=0;
00180 for(int i=0;i<3;i++)
00181 files+=trainfile[i].size();
00182 if(files<1)
00183 {
00184 printf("Specify a input file!\n");
00185 return 0;
00186 }
00187
00188 double filePOTs[3];
00189 for(int i=0;i<3;i++)
00190 {
00191 filePOTs[i]=0;
00192 TChain *pt = new TChain("pottree");
00193 for(unsigned int j=0;j<trainfile[i].size();j++)
00194 pt->Add(trainfile[i][j].c_str());
00195
00196 int ent = pt->GetEntries();
00197 for(int j=0;j<ent;j++)
00198 {
00199 pt->GetEntry(j);
00200 filePOTs[i]+= pt->GetLeaf("pot")->GetValue();
00201 }
00202
00203 printf("file type %d has %f POTs\n",i,filePOTs[i]);
00204 }
00205
00206
00207
00208 TChain *c = new TChain("ana_nue");
00209
00210 c->SetBranchAddress("NueRecord",&nr);
00211
00212
00213
00214 c->SetBranchStatus("*",1);
00215
00216 c->SetBranchStatus("shwfit*",0);
00217 c->SetBranchStatus("hitcalc*",0);
00218 c->SetBranchStatus("angcluster*",0);
00219 c->SetBranchStatus("angclusterfit*",0);
00220 c->SetBranchStatus("mstvars*",0);
00221 c->SetBranchStatus("fracvars*",0);
00222 c->SetBranchStatus("subshowervars*",0);
00223 c->SetBranchStatus("highhitvars*",0);
00224 c->SetBranchStatus("shieldrejvars*",0);
00225 c->SetBranchStatus("ann*",0);
00226 c->SetBranchStatus("anainfo*",0);
00227 c->SetBranchStatus("srevent*",0);
00228 c->SetBranchStatus("srshower*",0);
00229 c->SetBranchStatus("srtrack*",0);
00230 c->SetBranchStatus("mctrue*",0);
00231
00232 c->SetBranchStatus("mdadiscri*",0);;
00233 c->SetBranchStatus("treepid*",0);
00234 c->SetBranchStatus("fluxinfo*",0);
00235 c->SetBranchStatus("fluxweights*",0);
00236 c->SetBranchStatus("xsecweights*",0);
00237 c->SetBranchStatus("shi*",0);
00238 c->SetBranchStatus("mri*",0);
00239 c->SetBranchStatus("cdi*",0);
00240 c->SetBranchStatus("timingvars*",0);
00241 c->SetBranchStatus("mcnnv*",0);
00242 c->SetBranchStatus("dtree*",0);
00243
00244 c->SetBranchStatus("precord*",0);
00245 c->SetBranchStatus("precordMRCC*",0);
00246
00247
00248 c->SetBranchStatus("precord.particles*",1);
00249 c->SetBranchStatus("precord.event*",1);
00250
00251 c->SetBranchStatus("fracvars.fract_2_planes",1);
00252 c->SetBranchStatus("fracvars.fract_4_planes",1);
00253 c->SetBranchStatus("fracvars.fract_6_planes",1);
00254 c->SetBranchStatus("fracvars.fract_8_counters",1);
00255 c->SetBranchStatus("fracvars.fract_road",1);
00256
00257 c->SetBranchStatus("shwfit.par_a",1);
00258 c->SetBranchStatus("shwfit.par_b",1);
00259 c->SetBranchStatus("shwfit.LongE",1);
00260 c->SetBranchStatus("shwfit.uv_molrad_peak_9s_2pe_dw",1);
00261 c->SetBranchStatus("shwfit.uv_rms_9s_2pe_dw",1);
00262
00263 c->SetBranchStatus("mstvars.e4w",1);
00264 c->SetBranchStatus("mstvars.o4w",1);
00265
00266 c->SetBranchStatus("fluxweights.totbeamweight",1);
00267 c->SetBranchStatus("fluxweights.RPtotbeamweight",1);
00268
00269 c->SetBranchStatus("mctrue.fNueClass",1);
00270 c->SetBranchStatus("mctrue.fOscProb",1);
00271 c->SetBranchStatus("mctrue.resonanceCode",1);
00272 c->SetBranchStatus("mctrue.nuEnergy",1);
00273 c->SetBranchStatus("mctrue.emShowerFraction",1);
00274 c->SetBranchStatus("mctrue.trueVisibleE",1);
00275 c->SetBranchStatus("mctrue.nuFlavor",1);
00276 c->SetBranchStatus("mctrue.nonOscNuFlavor",1);
00277 c->SetBranchStatus("mctrue.interactionType",1);
00278
00279 c->SetBranchStatus("shwfit.contPlaneCount050",1);
00280 c->SetBranchStatus("shwfit.UVSlope",1);
00281 c->SetBranchStatus("srtrack.endX",1);
00282 c->SetBranchStatus("srtrack.endY",1);
00283 c->SetBranchStatus("srtrack.endZ",1);
00284 c->SetBranchStatus("srtrack.vtxX",1);
00285 c->SetBranchStatus("srtrack.vtxY",1);
00286 c->SetBranchStatus("srtrack.vtxZ",1);
00287
00288 c->SetBranchStatus("srtrack.endPlane",1);
00289 c->SetBranchStatus("srtrack.begPlane",1);
00290 c->SetBranchStatus("srtrack.trklikePlanes",1);
00291 c->SetBranchStatus("srevent.tracks",1);
00292 c->SetBranchStatus("srevent.largestEvent",1);
00293 c->SetBranchStatus("srevent.showers",1);
00294 c->SetBranchStatus("srevent.phNueGeV",1);
00295 c->SetBranchStatus("srevent.vtxX",1);
00296 c->SetBranchStatus("srevent.vtxY",1);
00297 c->SetBranchStatus("srevent.vtxZ",1);
00298
00299
00300
00301
00302
00303 Double_t pid=0;
00304 TMultiLayerPerceptron *mlp=0;
00305 if(makeTestTree==1)
00306 {
00307 TFile *fmlp = TFile::Open("MLP.root");
00308 if(fmlp)mlp = (TMultiLayerPerceptron*) fmlp->Get("MLP")->Clone();
00309 if(!mlp)
00310 {
00311 printf("NO MLP FOUND when making test tree\n");
00312
00313
00314 }
00315 if(fmlp)fmlp->Close();
00316
00317 }
00318
00319
00320
00321 TFile *fout=TFile::Open(inputFile.c_str(),"RECREATE");
00322 fout->cd();
00323 TTree *sim = new TTree("TrainTree","TrainTree");
00324
00325
00326 sim->Branch("mctrue_oscprob",&mctrue_oscprob);
00327 sim->Branch("mctrue_totbeamweight",&mctrue_totbeamweight);
00328 sim->Branch("type",&type);
00329 sim->Branch("mctrue_fNueClass",&mctrue_type);
00330 sim->Branch("weight",&weight);
00331 sim->Branch("pars",&pars,"pars[30]/D");
00332
00333 sim->Branch("mctrue_nuEnergy",&mctrue_nuenergy);
00334 sim->Branch("visenergy",&event_visenergy);
00335 sim->Branch("resonanceCode",&mctrue_iresonance);
00336 sim->Branch("tweight",&tweight);
00337 sim->Branch("emfrac",&emfrac);
00338
00339 Int_t preselect;
00340 sim->Branch("strict_preselection",&preselect);
00341
00342 if(mlp && makeTestTree==1)
00343 {
00344 sim->Branch("pid",&pid);
00345 }
00346
00347
00348 for(int i=0;i<20;i++)pars[i]=0;
00349
00350 for(unsigned int k=0;k<3;k++)
00351 for(unsigned int i=0;i<trainfile[k].size();i++)
00352 {
00353 int added =c->Add(trainfile[k][i].c_str());
00354 if(!added)
00355 {
00356 printf("input file %s bad\n",trainfile[k][i].c_str());
00357 return 0;
00358 }else{
00359 printf("added %d files from %s\n",added,trainfile[k][i].c_str());
00360 }
00361 }
00362
00363
00364 c->GetEntry(0);
00365 int det = nr->GetHeader().GetVldContext().GetDetector();
00366 int mc = nr->GetHeader().GetVldContext().GetSimFlag();
00367 printf("detector %d mc %d\n",det,mc);
00368
00369
00370
00371 NueStandard::SetDefaultOscParam();
00372 NueStandard::SetOscNoMatter();
00373
00374
00375
00376 double oscparams[20];
00377 NueStandard::GetOscParam(oscparams);
00378 if(det==Detector::kFar)
00379 for(int oi = 0; oi < int(OscPar::kUnknown); oi++){
00380 printf("oscpar %d: %f\n",oi,oscparams[oi]);
00381 }
00382
00383
00384
00385 int npars=0;
00386
00387 int ent = c->GetEntries();
00388 printf("%d total entries\n",ent);
00389
00390
00391 double type_weight[5];
00392 for(int i=0;i<5;i++)type_weight[i]=1;
00393 if(!makeTestTree && det==Detector::kFar && mc==SimFlag::kMC)
00394 {
00395 for(int i=0;i<5;i++)type_weight[i]=0;
00396 for(int i=0;i<ent;i++)
00397 {
00398 c->GetEntry(i);
00399 if(nr->mctrue.fNueClass<0||nr->mctrue.fNueClass>4)continue;
00400
00401
00402 double w = nr->mctrue.fOscProb*nr->fluxweights.totbeamweight;
00403 if(w>type_weight[nr->mctrue.fNueClass])type_weight[nr->mctrue.fNueClass]=w;
00404 }
00405 for(int i=0;i<5;i++)printf("type_weight %d %f\n",i,type_weight[i]);
00406
00407 }
00408
00409
00410 int failedpval=0;
00411 int failedpp=0;
00412 int failedspp=0;
00413 int faileddq=0;
00414
00415
00416
00417
00418
00419 int require_nue_presel=0;
00420 int require_pid_presel=0;
00421 int require_strict_pid_presel=1;
00422 int checkNueVars=0;
00423 int checkParticleVars=1;
00424
00425 for(int i=0;i<ent;i++)
00426 {
00427 c->GetEntry(i);
00428 if(i%10000==0)printf("%2.2f%%...\n",100.*i/ent);
00429
00430 det = nr->GetHeader().GetVldContext().GetDetector();
00431 mc = nr->GetHeader().GetVldContext().GetSimFlag();
00432
00433 if(mc!=SimFlag::kMC)
00434 nr->mctrue.fNueClass=0;
00435
00436 mctrue_type=nr->mctrue.fNueClass;
00437
00438
00439 if(mctrue_type<0||mctrue_type>4)continue;
00440
00441
00442
00443 mctrue_nuenergy = nr->mctrue.nuEnergy;
00444 mctrue_iresonance = nr->mctrue.resonanceCode;
00445 event_visenergy = nr->precord.event.visenergy;
00446 emfrac = nr->mctrue.emShowerFraction;
00447
00450
00451
00452 double pots=0;
00453 if(det==Detector::kNear)pots=filePOTs[0];
00454 else
00455 {
00456
00457
00458 if(nr->mctrue.interactionType==0)
00459 {
00460 pots=filePOTs[0]+filePOTs[1]+filePOTs[2];
00461 }else if (abs(nr->mctrue.nonOscNuFlavor)==14&&
00462 abs(nr->mctrue.nuFlavor)==14)
00463 {
00464 pots=filePOTs[0];
00465 }else if (abs(nr->mctrue.nonOscNuFlavor)==12&&
00466 abs(nr->mctrue.nuFlavor)==12)
00467 {
00468 pots=filePOTs[0];
00469 }else if (abs(nr->mctrue.nonOscNuFlavor)==12&&
00470 abs(nr->mctrue.nuFlavor)==14)
00471 {
00472 pots=filePOTs[2];
00473 }else if (abs(nr->mctrue.nonOscNuFlavor)==14&&
00474 abs(nr->mctrue.nuFlavor)==12)
00475 {
00476 pots=filePOTs[2];
00477 }else if (abs(nr->mctrue.nonOscNuFlavor)==14&&
00478 abs(nr->mctrue.nuFlavor)==16)
00479 {
00480 pots=filePOTs[1];
00481 }else if (abs(nr->mctrue.nonOscNuFlavor)==12&&
00482 abs(nr->mctrue.nuFlavor)==16)
00483 {
00484 pots=filePOTs[1];
00485 }
00486 }
00487
00488
00489
00490 if(!pots)continue;
00491
00492
00493
00494 mctrue_totbeamweight= (mc==SimFlag::kData) ? 1 :
00495 NueStandard::GetRPWBeamWeight(nr);
00496
00497 if(det == Detector::kNear && mc == SimFlag::kData)
00498 if(!NueStandard::PassesDataQuality(nr)){faileddq++;continue;}
00499
00500
00503
00504
00505
00508
00509 mctrue_oscprob=1;
00510
00511 if(det == Detector::kFar && mc == SimFlag::kMC)
00512 {
00513 mctrue_oscprob = (nr->mctrue.fNueClass==0) ? 1 :
00514
00515
00516
00517 NueStandard::GetOscWeight(nr->mctrue.nuFlavor,nr->mctrue.nonOscNuFlavor,nr->mctrue.nuEnergy);
00518 if(mctrue_oscprob<-1000)continue;
00519 }
00520
00521
00524
00525
00526
00527
00530
00531
00532 tweight = 0;
00533
00534 if(det == Detector::kFar)
00535 {
00536 if(nr->mctrue.interactionType==0)
00537 {
00538 tweight=filePOTs[0]/(filePOTs[0]+filePOTs[2]);
00539 }else if (abs(nr->mctrue.nonOscNuFlavor)==14&&
00540 abs(nr->mctrue.nuFlavor)==14)
00541 {
00542 tweight=mctrue_oscprob;
00543 }else if (abs(nr->mctrue.nonOscNuFlavor)==12&&
00544 abs(nr->mctrue.nuFlavor)==12)
00545 {
00546 }else if (abs(nr->mctrue.nonOscNuFlavor)==12&&
00547 abs(nr->mctrue.nuFlavor)==14)
00548 {
00549 tweight=mctrue_oscprob;
00550 }else if (abs(nr->mctrue.nonOscNuFlavor)==14&&
00551 abs(nr->mctrue.nuFlavor)==12)
00552 {
00553 tweight=mctrue_oscprob/0.075;
00554 }else if (abs(nr->mctrue.nonOscNuFlavor)==14&&
00555 abs(nr->mctrue.nuFlavor)==16)
00556 {
00557 }else if (abs(nr->mctrue.nonOscNuFlavor)==12&&
00558 abs(nr->mctrue.nuFlavor)==16)
00559 {
00560 }
00561 }
00562
00563
00564
00565 if(!makeTestTree && det==Detector::kFar && mc == SimFlag::kMC)
00566 {
00567
00568
00569
00570
00571 if( tweight<gRandom->Uniform())continue;
00572 }
00573
00574
00577
00578
00579
00580
00581 weight=3.25*(1e8)/pots;
00582
00583
00584
00585
00586
00587 PRecord * pr = &nr->precord;
00588
00589
00590 trueEMFrac = nr->mctrue.emShowerFraction;
00591
00592
00593 if(trueEMFrac>1)trueEMFrac=1;
00594
00595
00596
00597
00598
00600
00601
00602 int pass_precord_preselect=1;
00603 preselect=1;
00604
00605 length_z = pr->event.max_z - pr->event.min_z;
00606 if(pr->particles.totvise==0)pass_precord_preselect=0;
00607 largest_frac=pr->particles.totvise ? pr->particles.largest_particle_e/pr->particles.totvise : 1;
00608
00609 ntot_lsps=pr->particles.ntot*pr->particles.longest_s_particle_s;
00610
00611
00612 prim_ae0=pr->particles.prim_par_e0 ? pr->particles.prim_par_a/pr->particles.prim_par_e0 : 0;
00613
00614
00616
00617 if(pr->event.inFiducial!=1){pass_precord_preselect=0;preselect=0;}
00618 if(pr->event.contained!=1){pass_precord_preselect=0;preselect=0;}
00619 if(pr->particles.ntot<1){pass_precord_preselect=0;preselect=0;}
00620
00621
00622
00623 if(pr->particles.longest_s_particle_s>2)pass_precord_preselect=0;
00624 if(pr->particles.longest_s_particle_s<0.0 )pass_precord_preselect=0;
00625
00626 if(pr->event.max_z-pr->event.min_z>2)pass_precord_preselect=0;
00627 if(pr->event.max_z-pr->event.min_z<0.0)pass_precord_preselect=0;
00628
00629 if(pr->event.visenergy/25.<0.0 )pass_precord_preselect=0;
00630 if(pr->event.visenergy/25.>10.)pass_precord_preselect=0;
00631
00632
00633 if(!pass_precord_preselect&&require_pid_presel){failedpp++;continue;}
00634
00635
00636
00637 if(pr->particles.longest_s_particle_s>1.2)preselect=0;
00638 if(pr->particles.longest_s_particle_s<0.1)preselect=0;
00639
00640 if(pr->event.max_z-pr->event.min_z>1.2)preselect=0;
00641 if(pr->event.max_z-pr->event.min_z<0.1)preselect=0;
00642
00643 if(pr->event.visenergy/25.<0.5)preselect=0;
00644 if(pr->event.visenergy/25.>8)preselect=0;
00645
00646 if(!preselect&&require_strict_pid_presel){failedspp++;continue;}
00647
00648
00649
00650
00651
00652
00653
00654
00655
00656
00657
00658
00659
00660 int pass = 1;
00661
00662
00663
00664 pass = pass && NueStandard::IsInFid(nr);
00665 pass = pass && NueStandard::PassesPreSelection(nr);
00666
00667 if(!pass && require_nue_presel)continue;
00668
00669
00670
00671
00672
00673
00674 if (nr->shwfit.par_a<-1000) pass=0;
00675 if (nr->shwfit.par_b<-1000)pass=0;
00676 if (nr->shwfit.uv_molrad_peak_9s_2pe_dw<-1000)pass=0;
00677 if (nr->shwfit.uv_rms_9s_2pe_dw<-1000) pass=0;
00678 if (nr->mstvars.e4w<-1000) pass=0;
00679 if (nr->mstvars.e4w>500)pass=0;
00680 if (nr->mstvars.o4w<-1000)pass=0;
00681 if (nr->mstvars.o4w>500) pass=0;
00682 if (nr->fracvars.fract_2_planes<-1000)pass=0;
00683 if (nr->fracvars.fract_4_planes<-1000)pass=0;
00684 if (nr->fracvars.fract_6_planes<-1000)pass=0;
00685 if (nr->fracvars.fract_8_counters<-1000) pass=0;
00686 if (nr->fracvars.fract_road<-1000) pass=0;
00687 if (nr->shwfit.LongE<-1000)pass=0;
00688 if (nr->shwfit.LongE>1000)pass=0;
00689 if(!pass && checkNueVars)continue;
00690
00691
00692 mstvar_combine = nr->mstvars.e4w+nr->mstvars.o4w;
00693
00694
00695
00696
00697
00700 pass=1;
00701 if(pr->particles.longest_s_particle_s<0 || pr->particles.longest_s_particle_s>6)pass=0;
00702
00703
00704
00705 if(pr->particles.longest_z<0 || pr->particles.longest_z>6)pass=0;
00706
00707
00708 if(pr->particles.ntot<0 || pr->particles.ntot>50)pass=0;
00709
00710
00711
00712
00713
00714 if(pr->particles.rms_r<0 || pr->particles.rms_r>100)pass=0;
00715
00716
00717 if(pr->particles.prim_par_e0<0 || pr->particles.prim_par_e0>40e3)pass=0;
00718 if(pr->particles.prim_par_chisq<0 || pr->particles.prim_par_chisq>1000)pass=0;
00719 if(pr->particles.largest_particle_peakdiff<-200 || pr->particles.largest_particle_peakdiff>200)pass=0;
00720
00721
00722
00723
00724
00725
00726
00727
00728
00729
00730
00731
00732
00733
00734
00735
00736
00737
00738
00739
00740
00741
00742
00743 if(length_z<0 || length_z>6)pass=0;
00744
00745
00746
00747
00748 if(nr->shwfit.par_b<0 || nr->shwfit.par_b>6)pass=0;
00749
00750
00751
00752 if(mstvar_combine<-1000 || nr->shwfit.par_b>1000)pass=0;
00753
00754
00755
00756
00757
00758 if(nr->shwfit.LongE<0 || nr->shwfit.LongE>1200)pass=0;
00759
00760 if(!pass && checkParticleVars){failedpval++;continue;}
00761
00762
00763
00764
00765
00766 isdis=0;
00767 if(nr->mctrue.resonanceCode==1003)isdis=1;
00768 isndis=isdis?0:1;
00769
00770
00771 type=1;
00772 mctrue_type=nr->mctrue.fNueClass;
00773 if(mctrue_type!=2)type=0;
00774
00775 isnue=mctrue_type==2;
00776 isnc=mctrue_type==0;
00777 iscc=mctrue_type==1;
00778 istau=mctrue_type==3;
00779 isbeamve=mctrue_type==4;
00780
00781
00782
00783
00784 largest_cmp_chisqndf = pr->particles.largest_particle_cmp_ndf ? pr->particles.largest_particle_cmp_chisq / pr->particles.largest_particle_cmp_ndf : 0;
00785
00786 if(largest_cmp_chisqndf>3)largest_cmp_chisqndf=3;
00787
00788
00789
00790
00791
00792
00793 for(int i=0;i<11;i++)
00794 pars[i]=0;
00795
00796
00797
00798 int z=0;
00799
00800
00801
00802
00803
00804
00805
00806
00807
00808
00809
00810
00811
00812
00813
00814
00815 pars[z++]=pr->particles.longest_s_particle_s;
00816
00817 pars[z++]=pr->particles.mol_rad_r;
00818
00819 pars[z++]=pr->particles.emfrac;
00820 pars[z++]=pr->particles.ntot;
00821 pars[z++]=pr->particles.weighted_phi;
00822 pars[z++]=largest_frac;
00823
00824
00825
00826 pars[z++]=pr->particles.prim_par_b;
00827 pars[z++]=pr->particles.prim_par_e0;
00828 pars[z++]=pr->particles.prim_par_chisq;
00829
00830 pars[z++]=largest_cmp_chisqndf;
00831
00832
00833
00834 pars[z++]=pr->particles.prim_par_a;
00835
00836 pars[z++]=pr->event.nclusters;
00837 pars[z++]=prim_ae0;
00838
00839
00840
00841
00842
00843
00844
00845
00846
00847
00848
00849
00850 npars=z;
00851
00852
00853 if(makeTestTree==1)
00854 {
00855
00856 if(mlp)pid=mlp->Evaluate(0,pars);
00857 }
00858
00859 sim->Fill();
00860 }
00861
00862
00863 printf("failed %d on dq\n",faileddq);
00864 printf("failed %d on pvals\n",failedpval);
00865 printf("failed %d on particle pre\n",failedpp);
00866 printf("failed %d on strict particle pre\n",failedspp);
00867
00868 sim->Write("TrainTree", TObject::kOverwrite);
00869 TTree * descriptor = new TTree("descriptor","descriptor");
00870 descriptor->Branch("npars",&npars);
00871 descriptor->Fill();
00872 descriptor->Write("descriptor", TObject::kOverwrite);
00873
00874
00875 fout->Close();
00876 return sim;
00877 }
00878
00879
00880
00881 void NNTrain::Train(int steps, int update, int method, string form,double ncemcut,double veemcut)
00882 {
00883 double pots=6.5*1e8*99;
00884
00885
00886 this->ncemcut=ncemcut;
00887 this->veemcut=veemcut;
00888
00889 TTree *simold=0;
00890 TTree *descriptorin=0;
00891 TFile * fin = TFile::Open(inputFile.c_str());
00892 if(fin)simold=(TTree*)fin->Get("TrainTree");
00893 if(fin)descriptorin=(TTree*)fin->Get("descriptor");
00894
00895
00896
00897 if(!fin)
00898 {
00899 simold = MakeTrainTree();
00900 TFile * fin = TFile::Open(inputFile.c_str());
00901 if(fin)simold=(TTree*)fin->Get("TrainTree");
00902 }
00903
00904
00905
00906 if(!simold)
00907 {
00908 printf("sim tree missing\n");
00909 exit(1);
00910 }
00911
00912 if(!descriptorin)
00913 {
00914 printf("descriptor tree missing\n");
00915 exit(1);
00916 }
00917
00918 TFile * f = new TFile("nntemp1.root","RECREATE");
00919 f->cd();
00920
00921
00922
00923 TTree * sim=0;
00924 TTree * descriptor = descriptorin->CloneTree();
00925
00926 if(needToWeight)
00927 {
00928 printf("need to reweight this tree...\n");
00929 sim=simold->CloneTree(0);
00930 SetBranches(sim);
00931 SetBranches(simold);
00932
00933 int ent=simold->GetEntries();
00934 printf("looking over %d entries...\n",ent);
00935 int sav=0;
00936 int stype[5];
00937 double sweight[5];
00938 for(int i=0;i<5;i++){stype[i]=0;sweight[i]=0;}
00939
00940
00941 double type_weight[5];
00942 for(int i=0;i<5;i++)type_weight[i]=0;
00943
00944 for(int i=0;i<ent;i++)
00945 {
00946 simold->GetEntry(i);
00947
00948 if(mctrue_type<0||mctrue_type>4)continue;
00949
00950
00951 double w = mctrue_oscprob*mctrue_totbeamweight;
00952 if(w>type_weight[mctrue_type])type_weight[mctrue_type]=w;
00953 }
00954 for(int i=0;i<5;i++)printf("type_weight %d %f\n",i,type_weight[i]);
00955
00956
00957 for(int i=0;i<ent;i++)
00958 {
00959 simold->GetEntry(i);
00960
00961
00962
00963
00964
00965
00966 if( tweight<gRandom->Uniform())continue;
00967
00968 sim->Fill();
00969 sav++;
00970 stype[mctrue_type]++;
00971 sweight[mctrue_type]+=tweight;
00972
00973 }
00974 printf("using %d entries...\n",sav);
00975 for(int i=0;i<5;i++)
00976 {
00977 printf("type: %d cnt: %d sumweight %f\n",i,stype[i],sweight[i]);
00978 }
00979 }else{
00980 sim=simold->CloneTree(-1,"fast");
00981 SetBranches(sim);
00982 }
00983
00984
00985 char structure[1000];
00986
00987 int npars = (int)descriptor->GetLeaf("npars")->GetValue();
00988
00989 sprintf(structure,"%s","");
00990 for(int i=0;i<npars-1;i++)sprintf(structure,"%s@pars[%d], ",structure,i);
00991 sprintf(structure,"%s@pars[%d] ",structure,npars-1);
00992 sprintf(structure,"%s:%s:type",structure,form.c_str());
00993
00994 if(mlp)delete mlp;mlp=0;
00995 mlp = new TMultiLayerPerceptron(structure,sim,"(Entry$+1)%2","(Entry$)%2");
00996
00997
00998 if(method==1)mlp->SetLearningMethod(TMultiLayerPerceptron::kSteepestDescent);
00999 if(method==2) mlp->SetLearningMethod(TMultiLayerPerceptron::kStochastic);
01000 if(method==3) mlp->SetLearningMethod(TMultiLayerPerceptron::kBatch);
01001
01002 if(method==4)mlp->SetLearningMethod(TMultiLayerPerceptron::kRibierePolak);
01003 if(method==5)mlp->SetLearningMethod(TMultiLayerPerceptron::kFletcherReeves);
01004 if(method==6)mlp->SetLearningMethod(TMultiLayerPerceptron::kBFGS);
01005
01006
01007
01008
01009
01010
01011
01012
01013
01014
01015
01016
01017
01018
01019
01020
01021 char tmpa[200];
01022 #ifdef dodraw
01023 sprintf(tmpa,"+,text,graph,update=%d",update);
01024 #else
01025 sprintf(tmpa,"+,text,update=%d",update);
01026 #endif
01027
01028
01029 if(!traincontinue)
01030 {
01031 mlp->Randomize();
01032 mlp->DumpWeights("ini.out");
01033 }
01034 mlp->LoadWeights("ini.out");
01035
01036
01037 mlp->Train(steps, tmpa);
01038
01039 mlp->DumpWeights("mlp.out");
01040
01041
01042 printf("training done\n");
01043
01044
01045
01046
01047 #ifdef dodraw
01048 TCanvas* mlpa_canvas = new TCanvas("mlpa_canvas","Network analysis");
01049 mlpa_canvas->Divide(2,2);
01050 TMLPAnalyzer ana(mlp);
01051
01052 ana.GatherInformations();
01053
01054 ana.CheckNetwork();
01055 mlpa_canvas->cd(1);
01056
01057 ana.DrawDInputs();
01058 mlpa_canvas->cd(2);
01059
01060 mlp->Draw();
01061 mlpa_canvas->cd(3);
01062
01063 ana.DrawNetwork(0,"type==1","type==0");
01064 mlpa_canvas->cd(4);
01065 #endif
01066
01067
01068
01069
01070 double wantpots= 3.25;
01071 double potscale = (wantpots*1e8)/pots;
01072
01073 printf("b\n");
01074
01075 int npidbins=250;
01076
01077 TH1F *bg = new TH1F("bgh", "NN output", npidbins, -.5, 1.5);
01078 TH1F *sign = new TH1F("sigh", "NN output", npidbins, -.5, 1.5);
01079
01080 TH2F *bg_recoE = new TH2F("bg_recoE", "bg_recoE", npidbins, -.5, 1.5, 20,0,10);
01081 TH2F *sig_recoE = new TH2F("sig_recoE", "sig_recoE", npidbins, -.5, 1.5, 20,0,10);
01082
01083
01084 bg->SetDirectory(0);
01085 sign->SetDirectory(0);
01086
01087
01088 Double_t pid=0;
01089 Double_t oscweight=0;
01090
01091
01092
01093
01094
01095
01096
01097
01098 printf("storing %d entries\n",(int)sim->GetEntries());
01099 TBranch * pidbranch = sim->Branch("pid",&pid);
01100 TBranch * oscweightbranch = sim->Branch("oscweight",&oscweight);
01101
01102
01103
01104
01105
01106
01107
01108
01109
01110 for (int i = 0; i < sim->GetEntries(); i++) {
01111
01112 sim->GetEntry(i);
01113
01114
01115 pid=mlp->Evaluate(0, pars);
01116
01117
01118 if(type==0)bg->Fill(pid,mctrue_oscprob * mctrue_totbeamweight * weight);
01119 if(type==1)sign->Fill(pid,mctrue_oscprob * mctrue_totbeamweight * weight);
01120
01121 if(type==0) bg_recoE->Fill(pid,event_visenergy/25.,mctrue_oscprob * mctrue_totbeamweight * weight);
01122 if(type==1) sig_recoE->Fill(pid,event_visenergy/25.,mctrue_oscprob * mctrue_totbeamweight * weight);
01123
01124 oscweight=mctrue_oscprob * mctrue_totbeamweight * potscale;
01125 oscweightbranch->Fill();
01126
01127
01128
01129 pidbranch->Fill();
01130 }
01131
01132
01133
01134
01135
01136 sim->Write("", TObject::kOverwrite);
01137
01138
01139 bg->SetLineColor(kBlue);
01140 bg->SetFillStyle(3008); bg->SetFillColor(kBlue);
01141 sign->SetLineColor(kRed);
01142 sign->SetFillStyle(3003); sign->SetFillColor(kRed);
01143 bg->SetStats(0);
01144 sign->SetStats(0);
01145 #ifdef dodraw
01146 bg->Draw();
01147 sign->Draw("same");
01148 TLegend *legend = new TLegend(.75, .80, .95, .95);
01149 legend->AddEntry(bg, "Background");
01150 legend->AddEntry(sign, "Signal");
01151 legend->Draw();
01152 mlpa_canvas->cd(0);
01153 #endif
01154
01155
01156
01157
01158 sign->Write();
01159 bg->Write();
01160
01161
01162 char tmp[200];
01163 sprintf(tmp,"fom @ %2.2f pots",wantpots);
01164
01165 TH1F * hz = new TH1F("fom",tmp,npidbins,-0.5,1.5);
01166 for(int i=0;i<npidbins;i++)
01167 {
01168 double d = bg->Integral(i,npidbins);
01169 if(d>0)
01170 {
01171 double n = sign->Integral(i,npidbins);
01172 hz->SetBinContent(i,n/sqrt(d));
01173
01174 }
01175
01176 }
01177 hz->Write();
01178
01179
01180 double bg_sys_frac=0.1;
01181 double sig_sys_frac=0.1;
01182
01183 sprintf(tmp,"multibin superfom @ %2.2f pots",wantpots);
01184 TH1F * hz1 = new TH1F("fom_mb_super",tmp,npidbins,-0.5,1.5);
01185 for(int i=0;i<npidbins;i++)
01186 {
01187 double d = bg->Integral(i,npidbins);
01188 if(d<=0)continue;
01189 {
01190
01191
01192 double sum_sigbg=0;
01193 double sum_sig=0;
01194 double sum_bg=0;
01195 double sum_sigsig=0;
01196
01197 TH1D *sE = sig_recoE->ProjectionY("sE",i,npidbins);
01198 TH1D *bE = bg_recoE->ProjectionY("bE",i,npidbins);
01199
01200 for(int k=1;k<sE->GetNbinsX()+1;k++)
01201 {
01202 double sig=sE->GetBinContent(k);
01203 double nbg=bE->GetBinContent(k);
01204 sum_sig+=sig;
01205 sum_bg+=nbg;
01206 sum_sigbg+=nbg*nbg*bg_sys_frac*bg_sys_frac;
01207 sum_sigsig+=sig*sig*sig_sys_frac*sig_sys_frac;
01208 }
01209
01210 double orig_sum_fom1=0;
01211 if(sum_bg)orig_sum_fom1 = sum_sig/sqrt(sum_sig+sum_bg+sum_sigbg+sum_sigsig);
01212
01213
01214
01215 hz1->SetBinContent(i,orig_sum_fom1);
01216
01217 }
01218
01219 }
01220 hz1->Write();
01221
01222
01223
01224 sprintf(tmp,"fom @ %2.2f pots",wantpots);
01225
01226 TH1F * hz2 = new TH1F("superfom",tmp,npidbins,-0.5,1.5);
01227 for(int i=0;i<npidbins;i++)
01228 {
01229 double d = bg->Integral(i,npidbins);
01230 if(d>0)
01231 {
01232 double n = sign->Integral(i,npidbins);
01233 hz2->SetBinContent(i,!d ? 0 : n/sqrt(n+d+d*d*bg_sys_frac*bg_sys_frac+n*n*sig_sys_frac*sig_sys_frac));
01234
01235 }
01236
01237 }
01238 hz2->Write();
01239
01240
01241 sprintf(tmp,"fom @ %2.2f pots",wantpots);
01242
01243 TH1F *hz2a = new TH1F("simplesuperfom",tmp,npidbins,-0.5,1.5);
01244 for(int i=0;i<npidbins;i++)
01245 {
01246 double d = bg->Integral(i,npidbins);
01247 if(d>0)
01248 {
01249 double n = sign->Integral(i,npidbins);
01250 hz2a->SetBinContent(i,!d ? 0 : n/sqrt(d+d*d*bg_sys_frac*bg_sys_frac));
01251
01252 }
01253
01254 }
01255 hz2a->Write();
01256
01257
01258
01259
01260
01261 int mb = hz->GetMaximumBin();
01262
01263 printf("Max fom at %2.2f pots is %f sig %f back %f cut above %f\n",wantpots,hz->GetMaximum(),sign->Integral(mb,npidbins),bg->Integral(mb,npidbins),hz->GetBinLowEdge(mb));
01264
01265 int mb1 = hz1->GetMaximumBin();
01266
01267 printf("Max multibin super fom at %2.2f pots is %f sig %f back %f cut above %f\n",wantpots,hz1->GetMaximum(),sign->Integral(mb1,npidbins),bg->Integral(mb1,npidbins),hz1->GetBinLowEdge(mb1));
01268
01269 int mb2 = hz2->GetMaximumBin();
01270
01271 printf("Max super fom at %2.2f pots is %f sig %f back %f cut above %f\n",wantpots,hz2->GetMaximum(),sign->Integral(mb2,npidbins),bg->Integral(mb2,npidbins),hz2->GetBinLowEdge(mb2));
01272
01273 printf("Max standard super fom at %2.2f pots is %f sig %f back %f cut above %f\n",wantpots,hz2a->GetMaximum(),sign->Integral(mb2,npidbins),bg->Integral(mb2,npidbins),hz2a->GetBinLowEdge(mb2));
01274
01275
01276 TFile * fmlp = TFile::Open("MLP.root","RECREATE");
01277 fmlp->cd();
01278 mlp->Write("MLP");
01279
01280
01281 #ifdef dodraw
01282 mlpa_canvas->SaveAs("trainImage.eps");
01283 #endif
01284
01285
01286 fmlp->Close();
01287 if(fin)fin->Close();
01288 f->Close();
01289
01290
01291
01292 }
01293
01294 void NNTrain::SetBranches(TTree *sim)
01295 {
01296 sim->SetMakeClass(1);
01297 sim->SetBranchStatus("*",1);
01298 sim->SetBranchAddress("weight",&weight);
01299 sim->SetBranchAddress("type",&type);
01300 sim->SetBranchAddress("pars",&pars);
01301 sim->SetBranchAddress("mctrue_oscprob",&mctrue_oscprob);
01302 sim->SetBranchAddress("mctrue_fNueClass",&mctrue_type);
01303 sim->SetBranchAddress("mctrue_totbeamweight",&mctrue_totbeamweight);
01304 sim->SetBranchAddress("mctrue_nuEnergy",&mctrue_nuenergy);
01305 sim->SetBranchAddress("visenergy",&event_visenergy);
01306 sim->SetBranchAddress("resonanceCode",&mctrue_iresonance);
01307 sim->SetBranchAddress("tweight",&tweight);
01308
01309 }
01310
01311
01312 void NNTrain::SetOscParamBase( float dm2, float ss13,
01313 float delta, int hierarchy){
01314
01315 Double_t dm2_12 = fDeltaMS12*1e-5;
01316 Double_t dm2_23 = dm2;
01317
01318 Double_t par[9] = {0};
01319 par[OscPar::kL] = fBaseLine;
01320 par[OscPar::kTh23] = fTh23;
01321 par[OscPar::kTh12] = fTh12;
01322 par[OscPar::kTh13] = ss13;
01323 par[OscPar::kDeltaM23] = hierarchy*dm2_23;
01324 par[OscPar::kDeltaM12] = dm2_12;
01325 par[OscPar::kDensity] = fDensity;
01326 par[OscPar::kDelta] = delta;
01327 par[OscPar::kNuAntiNu] = 1;
01328
01329
01330 fOscCalc.SetOscParam(par);
01331 }
01332
01333
01334
01335
01336 double NNTrain::osc(double nuEnergy, int interactionType, int nonOscFlavor, int oscFlavor)
01337 {
01338
01339
01340
01341 if(nuEnergy<0)nuEnergy=-nuEnergy;
01342 if(oscFlavor<0)oscFlavor=-oscFlavor;
01343 if(nonOscFlavor<0)nonOscFlavor=-nonOscFlavor;
01344
01345 if (interactionType==0){
01346 return OscillationProb(f2,0,nuEnergy);
01347 }
01348 if (nonOscFlavor==14&&oscFlavor==14){
01349 return OscillationProb(f2,1,nuEnergy);
01350 }
01351 if (nonOscFlavor==12&&oscFlavor==12){
01352 return OscillationProb(f2,4,nuEnergy);
01353
01354 }
01355 if (nonOscFlavor==12&&oscFlavor==14){
01356 return OscillationProb(f2,6,nuEnergy);
01357 }
01358 if (nonOscFlavor==14&&oscFlavor==12){
01359 return OscillationProb(f2,2,nuEnergy);
01360 }
01361 if (nonOscFlavor==14&&oscFlavor==16){
01362 return OscillationProb(f2,3,nuEnergy);
01363
01364 }
01365 if (nonOscFlavor==12&&oscFlavor==16){
01366 return OscillationProb(f2,5,nuEnergy);
01367
01368 }
01369
01370
01371 return 0;
01372 }
01373
01374
01375 float NNTrain::OscillationProb(TF1* f2, int ntype, float NuE, float sinth23, float sin2th13) {
01376
01377
01378
01379
01380 float OscProb = 0 ;
01381 float NumuToNutau ;
01382 float NumuToNue ;
01383 float NueSurvival ;
01384 float NumuSurvival ;
01385 float NueToNutau ;
01386 float NueToNumu;
01387
01388 if(NuE<0)NuE=-NuE;
01389
01390 if (ntype==0)
01391 {
01392 OscProb = 1 ;
01393 return OscProb;
01394 }
01395
01396
01397
01398
01399 if (ntype==4)
01400 {
01401 NueSurvival = 1.- sin2th13*f2->Eval(NuE) ;
01402 OscProb = NueSurvival ;
01403 return OscProb;
01404 }
01405
01406 if (ntype==5)
01407 {
01408 NueToNutau = (1.-sinth23)*sin2th13*f2->Eval(NuE) ;
01409 OscProb = NueToNutau ;
01410 return OscProb;
01411 }
01412
01413
01414 NumuToNue = sinth23*sin2th13*f2->Eval(NuE) ;
01415 NueToNumu = NumuToNue;
01416 if (ntype==6){ OscProb = NueToNumu; return OscProb;}
01417 if (ntype==2){ OscProb = NumuToNue; return OscProb;}
01418
01419
01420 NumuToNutau = 4.*sinth23*(1.-sinth23)*pow(1-sin2th13/4,2) ;
01421 NumuToNutau *= f2->Eval(NuE) ;
01422
01423 if (ntype==3){ OscProb = NumuToNutau; return OscProb;}
01424
01425
01426
01427 if (ntype==1 )
01428 {
01429 NumuSurvival = 1. - NumuToNutau - NumuToNue ;
01430 OscProb = NumuSurvival ;
01431
01432 return OscProb;
01433 }
01434
01435
01436
01437
01438
01439
01440
01441
01442
01443
01444
01445
01446
01447
01448
01449
01450
01451
01452
01453
01454
01455
01456 return OscProb ;
01457 }
01458
01459