diff --git a/src/LoopSplitter.cpp b/src/LoopSplitter.cpp index 54ff1ca0f7f1c41ec0cbe8a0e7cc0eea03a54ebc..120681be917ff50cf9f179ecac43d532781972e7 100644 --- a/src/LoopSplitter.cpp +++ b/src/LoopSplitter.cpp @@ -33,19 +33,17 @@ void LoopSplitter::addAdapterBasicBlocks(Instruction * SP, Value * Idx) { // and connect it to CollectBB, use switch Inst IRBuilder<> Builder(&F->getEntryBlock().front()); auto BrTgtArray = createArray(F, Builder.getInt32Ty(), 32 /*XXX Max Batch size*/); - //auto BrTargetAlloca = Builder.CreateAlloca(Builder.getInt32Ty()); Builder.SetInsertPoint(DistBB); auto IdxVal = Builder.CreateLoad(Idx); auto IdxVal64 = Builder.CreateSExtOrBitCast(IdxVal, Builder.getInt64Ty()); auto BrValPtr = Builder.CreateGEP(BrTgtArray, {Builder.getInt64(0), IdxVal64}); auto BrVal = Builder.CreateLoad(BrValPtr); - auto SwitchI = Builder.CreateSwitch(BrVal, BottomHalf); + SwitchI = Builder.CreateSwitch(BrVal, BottomHalf); // XXX We assume now CFG we have is the one after block // predication transformation. SmallVector DivergeBlocks; - SmallVector, 4> TargetBlocks; DivergeBlocks.push_back(TopHalf->getUniquePredecessor()); for (auto & DivergeBB : DivergeBlocks) { @@ -64,6 +62,9 @@ void LoopSplitter::addAdapterBasicBlocks(Instruction * SP, Value * Idx) { auto BrValPtr = Builder.CreateGEP(BrTgtArray, {Builder.getInt64(0), IdxVal64}); Builder.CreateStore(TgtBBVal, BrValPtr); TermI->setSuccessor(1, CollectBB); + + + SwitchI->addCase(BBToId[FalseBB], FalseBB); } } @@ -141,6 +142,15 @@ void LoopSplitter::doLoopSplit(Function * F, Loop * L0, BasicBlock * SplitBlock) errs() << *IndexVar << "\n"; Builder.CreateStore(Builder.getInt32(0), IndexVar); Builder.CreateBr(OldHeader); + + // If FalseBB is terminating instruction, use latch block as target instead. + SmallVector Returns; + getReturnBlocks(F, Returns); + for (auto & Case : SwitchI->cases()) { + if (find(Returns, Case.getCaseSuccessor()) != Returns.end()) { + Case.setSuccessor(L0->getLoopLatch()); + } + } } bool LoopSplitter::run() { diff --git a/src/LoopSplitter.h b/src/LoopSplitter.h index 8f074a52aafae1de4abd46698eccd267a054bceb..96191fccd9ef47df6f995159358efdda1b7acde4 100644 --- a/src/LoopSplitter.h +++ b/src/LoopSplitter.h @@ -6,6 +6,7 @@ #include #include #include +#include #include #include @@ -34,6 +35,7 @@ class LoopSplitter { llvm::Function * F; llvm::LoopInfo * LI; Stats stat; + llvm::SwitchInst * SwitchI; llvm::BasicBlock * ExitBlock; llvm::DenseMap BBToId; diff --git a/src/Util.cpp b/src/Util.cpp index e107f1ef25e5b012a1b789db3d8b2d59f8643854..a7b9c4d381246209be9f55c5f788fb1d20987f82 100644 --- a/src/Util.cpp +++ b/src/Util.cpp @@ -304,6 +304,12 @@ void cloneBasicBlocksInto(Function * From, Function * To) { CloneFunctionInto(To, From, VMap, From->getSubprogram() != nullptr, Returns); } +void getReturnBlocks(Function * F, SmallVectorImpl & Returns) { + for (BasicBlock & BB : *F) + if (isa(BB.getTerminator())) + Returns.push_back(&BB); +} + void getReturnInstList(Function * F, SmallVectorImpl & Result) { // Check all the return blocks. for (BasicBlock & BB : *F) diff --git a/src/Util.h b/src/Util.h index ed15c9c2c129bb6d3f37e9f8bea0863b9ba1b4c7..7dd3e534b02f8bd93b1344c70fb109d3736050cf 100644 --- a/src/Util.h +++ b/src/Util.h @@ -67,6 +67,10 @@ void setSuccessor(llvm::BasicBlock * BB, llvm::BasicBlock * SuccBB, void cloneBasicBlocksInto(llvm::Function * From, llvm::Function * To); + +void getReturnBlocks(llvm::Function * F, + llvm::SmallVectorImpl & Returns); + void getReturnInstList(llvm::Function * F, llvm::SmallVectorImpl & Result);