From acfd73637275c9a92bae971188f178809c45d7d3 Mon Sep 17 00:00:00 2001 From: guruhegde Date: Wed, 23 Oct 2019 23:27:27 +0200 Subject: [PATCH] Handle diverge block targeting return block --- src/LoopSplitter.cpp | 16 +++++++++++++--- src/LoopSplitter.h | 2 ++ src/Util.cpp | 6 ++++++ src/Util.h | 4 ++++ 4 files changed, 25 insertions(+), 3 deletions(-) diff --git a/src/LoopSplitter.cpp b/src/LoopSplitter.cpp index 54ff1ca..120681b 100644 --- a/src/LoopSplitter.cpp +++ b/src/LoopSplitter.cpp @@ -33,19 +33,17 @@ void LoopSplitter::addAdapterBasicBlocks(Instruction * SP, Value * Idx) { // and connect it to CollectBB, use switch Inst IRBuilder<> Builder(&F->getEntryBlock().front()); auto BrTgtArray = createArray(F, Builder.getInt32Ty(), 32 /*XXX Max Batch size*/); - //auto BrTargetAlloca = Builder.CreateAlloca(Builder.getInt32Ty()); Builder.SetInsertPoint(DistBB); auto IdxVal = Builder.CreateLoad(Idx); auto IdxVal64 = Builder.CreateSExtOrBitCast(IdxVal, Builder.getInt64Ty()); auto BrValPtr = Builder.CreateGEP(BrTgtArray, {Builder.getInt64(0), IdxVal64}); auto BrVal = Builder.CreateLoad(BrValPtr); - auto SwitchI = Builder.CreateSwitch(BrVal, BottomHalf); + SwitchI = Builder.CreateSwitch(BrVal, BottomHalf); // XXX We assume now CFG we have is the one after block // predication transformation. SmallVector DivergeBlocks; - SmallVector, 4> TargetBlocks; DivergeBlocks.push_back(TopHalf->getUniquePredecessor()); for (auto & DivergeBB : DivergeBlocks) { @@ -64,6 +62,9 @@ void LoopSplitter::addAdapterBasicBlocks(Instruction * SP, Value * Idx) { auto BrValPtr = Builder.CreateGEP(BrTgtArray, {Builder.getInt64(0), IdxVal64}); Builder.CreateStore(TgtBBVal, BrValPtr); TermI->setSuccessor(1, CollectBB); + + + SwitchI->addCase(BBToId[FalseBB], FalseBB); } } @@ -141,6 +142,15 @@ void LoopSplitter::doLoopSplit(Function * F, Loop * L0, BasicBlock * SplitBlock) errs() << *IndexVar << "\n"; Builder.CreateStore(Builder.getInt32(0), IndexVar); Builder.CreateBr(OldHeader); + + // If FalseBB is terminating instruction, use latch block as target instead. + SmallVector Returns; + getReturnBlocks(F, Returns); + for (auto & Case : SwitchI->cases()) { + if (find(Returns, Case.getCaseSuccessor()) != Returns.end()) { + Case.setSuccessor(L0->getLoopLatch()); + } + } } bool LoopSplitter::run() { diff --git a/src/LoopSplitter.h b/src/LoopSplitter.h index 8f074a5..96191fc 100644 --- a/src/LoopSplitter.h +++ b/src/LoopSplitter.h @@ -6,6 +6,7 @@ #include #include #include +#include #include #include @@ -34,6 +35,7 @@ class LoopSplitter { llvm::Function * F; llvm::LoopInfo * LI; Stats stat; + llvm::SwitchInst * SwitchI; llvm::BasicBlock * ExitBlock; llvm::DenseMap BBToId; diff --git a/src/Util.cpp b/src/Util.cpp index e107f1e..a7b9c4d 100644 --- a/src/Util.cpp +++ b/src/Util.cpp @@ -304,6 +304,12 @@ void cloneBasicBlocksInto(Function * From, Function * To) { CloneFunctionInto(To, From, VMap, From->getSubprogram() != nullptr, Returns); } +void getReturnBlocks(Function * F, SmallVectorImpl & Returns) { + for (BasicBlock & BB : *F) + if (isa(BB.getTerminator())) + Returns.push_back(&BB); +} + void getReturnInstList(Function * F, SmallVectorImpl & Result) { // Check all the return blocks. for (BasicBlock & BB : *F) diff --git a/src/Util.h b/src/Util.h index ed15c9c..7dd3e53 100644 --- a/src/Util.h +++ b/src/Util.h @@ -67,6 +67,10 @@ void setSuccessor(llvm::BasicBlock * BB, llvm::BasicBlock * SuccBB, void cloneBasicBlocksInto(llvm::Function * From, llvm::Function * To); + +void getReturnBlocks(llvm::Function * F, + llvm::SmallVectorImpl & Returns); + void getReturnInstList(llvm::Function * F, llvm::SmallVectorImpl & Result); -- GitLab