Commit acfd7363 authored by guruhegde's avatar guruhegde

Handle diverge block targeting return block

parent d14eada2
......@@ -33,19 +33,17 @@ void LoopSplitter::addAdapterBasicBlocks(Instruction * SP, Value * Idx) {
// and connect it to CollectBB, use switch Inst
IRBuilder<> Builder(&F->getEntryBlock().front());
auto BrTgtArray = createArray(F, Builder.getInt32Ty(), 32 /*XXX Max Batch size*/);
//auto BrTargetAlloca = Builder.CreateAlloca(Builder.getInt32Ty());
Builder.SetInsertPoint(DistBB);
auto IdxVal = Builder.CreateLoad(Idx);
auto IdxVal64 = Builder.CreateSExtOrBitCast(IdxVal, Builder.getInt64Ty());
auto BrValPtr = Builder.CreateGEP(BrTgtArray, {Builder.getInt64(0), IdxVal64});
auto BrVal = Builder.CreateLoad(BrValPtr);
auto SwitchI = Builder.CreateSwitch(BrVal, BottomHalf);
SwitchI = Builder.CreateSwitch(BrVal, BottomHalf);
// XXX We assume now CFG we have is the one after block
// predication transformation.
SmallVector<BasicBlock *, 4> DivergeBlocks;
SmallVector<std::pair<Value *, BasicBlock *>, 4> TargetBlocks;
DivergeBlocks.push_back(TopHalf->getUniquePredecessor());
for (auto & DivergeBB : DivergeBlocks) {
......@@ -64,6 +62,9 @@ void LoopSplitter::addAdapterBasicBlocks(Instruction * SP, Value * Idx) {
auto BrValPtr = Builder.CreateGEP(BrTgtArray, {Builder.getInt64(0), IdxVal64});
Builder.CreateStore(TgtBBVal, BrValPtr);
TermI->setSuccessor(1, CollectBB);
SwitchI->addCase(BBToId[FalseBB], FalseBB);
}
}
......@@ -141,6 +142,15 @@ void LoopSplitter::doLoopSplit(Function * F, Loop * L0, BasicBlock * SplitBlock)
errs() << *IndexVar << "\n";
Builder.CreateStore(Builder.getInt32(0), IndexVar);
Builder.CreateBr(OldHeader);
// If FalseBB is terminating instruction, use latch block as target instead.
SmallVector<BasicBlock *, 4> Returns;
getReturnBlocks(F, Returns);
for (auto & Case : SwitchI->cases()) {
if (find(Returns, Case.getCaseSuccessor()) != Returns.end()) {
Case.setSuccessor(L0->getLoopLatch());
}
}
}
bool LoopSplitter::run() {
......
......@@ -6,6 +6,7 @@
#include <llvm/ADT/SmallVector.h>
#include <llvm/IR/Constants.h>
#include <llvm/IR/Function.h>
#include <llvm/IR/Instructions.h>
#include <llvm/IR/Module.h>
#include <llvm/IR/Value.h>
......@@ -34,6 +35,7 @@ class LoopSplitter {
llvm::Function * F;
llvm::LoopInfo * LI;
Stats stat;
llvm::SwitchInst * SwitchI;
llvm::BasicBlock * ExitBlock;
llvm::DenseMap<const llvm::BasicBlock *, llvm::ConstantInt *> BBToId;
......
......@@ -304,6 +304,12 @@ void cloneBasicBlocksInto(Function * From, Function * To) {
CloneFunctionInto(To, From, VMap, From->getSubprogram() != nullptr, Returns);
}
void getReturnBlocks(Function * F, SmallVectorImpl<BasicBlock *> & Returns) {
for (BasicBlock & BB : *F)
if (isa<ReturnInst>(BB.getTerminator()))
Returns.push_back(&BB);
}
void getReturnInstList(Function * F, SmallVectorImpl<ReturnInst *> & Result) {
// Check all the return blocks.
for (BasicBlock & BB : *F)
......
......@@ -67,6 +67,10 @@ void setSuccessor(llvm::BasicBlock * BB, llvm::BasicBlock * SuccBB,
void cloneBasicBlocksInto(llvm::Function * From, llvm::Function * To);
void getReturnBlocks(llvm::Function * F,
llvm::SmallVectorImpl<llvm::BasicBlock *> & Returns);
void getReturnInstList(llvm::Function * F,
llvm::SmallVectorImpl<llvm::ReturnInst *> & Result);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment