Commit 743cf699 authored by guruhegde's avatar guruhegde

Reimplement LoopSplitter logic

parent 52b5bb83
......@@ -12,6 +12,7 @@ void IRLoop::extractLoopSkeleton(Loop * L) {
Latch = L->getLoopLatch();
IdxAlloca = getLoopIndexVar(L);
PreHeader = L->getLoopPreheader();
ExitBlock = Header->getTerminator()->getSuccessor(1);
}
void IRLoop::analyze(Loop * L) {
......@@ -23,10 +24,7 @@ void IRLoop::analyze(Loop * L) {
}
}
void IRLoop::constructEmptyLoop(AllocaInst * TripCount, BasicBlock * ExitBlock) {
auto & Ctx = ExitBlock->getContext();
const auto & F = ExitBlock->getParent();
void IRLoop::constructEmptyLoop(AllocaInst * TripCount, Function * F) {
Latch = BasicBlock::Create(Ctx, "latch", F);
Header = BasicBlock::Create(Ctx, "header", F, Latch);
PreHeader = BasicBlock::Create(Ctx, "preheader", F, Header);
......@@ -56,7 +54,7 @@ void IRLoop::constructEmptyLoop(AllocaInst * TripCount, BasicBlock * ExitBlock)
auto IdxVal = Builder.CreateLoad(IdxAlloca);
auto TC = Builder.CreateLoad(TripCount);
auto * Icmp = Builder.CreateICmpSLT(IdxVal, TC, "loop-predicate");
Builder.CreateCondBr(Icmp, EmptyBody, ExitBlock);
Builder.CreateCondBr(Icmp, EmptyBody, EmptyBody);
}
void IRLoop::setLoopBlocks(std::vector<BasicBlock *> & BlockList) {
......
......@@ -7,24 +7,28 @@
#include <llvm/IR/Module.h>
#include <llvm/IR/Value.h>
#include "Util.h"
namespace tas {
// This is similar to Loop object available in LoopInfo object
// but simple.
class IRLoop {
llvm::LLVMContext & Ctx;
llvm::BasicBlock * PreHeader;
llvm::BasicBlock * Header;
llvm::BasicBlock * Latch;
llvm::BasicBlock * EmptyBody; // Useful for creating empty loop.
llvm::AllocaInst * IdxAlloca;
llvm::SmallVector<llvm::BasicBlock *, 4> Blocks;
llvm::BasicBlock * ExitBlock;
public:
IRLoop() = default;
IRLoop(llvm::LLVMContext & C) : Ctx(C) {}
void analyze(llvm::Loop * L);
void constructEmptyLoop(llvm::AllocaInst * TripCount,
llvm::BasicBlock * InsertAfter);
llvm::Function * F);
void extractLoopSkeleton(llvm::Loop * L);
void setLoopBlocks(std::vector<llvm::BasicBlock *> & Blocks);
......@@ -36,6 +40,22 @@ public:
return Header;
}
// Accessor for the loop's latch block.
// Returns the Latch pointer held by this IRLoop; the value is whatever was
// stored by extractLoopSkeleton() (from Loop::getLoopLatch()) or created by
// constructEmptyLoop(). No ownership is transferred.
llvm::BasicBlock * getLatch() {
return Latch;
}
// Accessor for the loop's exit block.
// Returns the ExitBlock pointer stored in this IRLoop, e.g. the one recorded
// by extractLoopSkeleton() from successor 1 of the header's terminator.
// NOTE(review): the visible constructEmptyLoop() does not appear to assign
// ExitBlock — confirm the member is initialized before calling this on a
// freshly constructed empty loop.
llvm::BasicBlock * getExitBlock() {
return ExitBlock;
}
// Redirects the loop's exit edge to ExitBB.
//
// Successor 1 of the header's terminator is treated as the exit edge (this is
// the same edge extractLoopSkeleton() reads to populate ExitBlock), so it is
// rewired to ExitBB via the setSuccessor() helper.
//
// The cached ExitBlock member is updated as well; previously it was left
// untouched, so getExitBlock() kept returning the old (or, for a loop built by
// constructEmptyLoop(), never-assigned) pointer after the CFG had changed.
//
// ExitBB: the block control should reach when the loop predicate fails.
void setExitBlock(llvm::BasicBlock * ExitBB) {
  setSuccessor(Header, ExitBB, 1);
  // Keep the cached pointer consistent with the rewired CFG edge.
  ExitBlock = ExitBB;
}
// Membership test: reports whether BB is one of the blocks recorded in this
// loop's Blocks list. Comparison is by pointer identity.
bool contains(llvm::BasicBlock * BB) {
  for (llvm::BasicBlock * Member : Blocks) {
    if (Member == BB) {
      return true;
    }
  }
  return false;
}
void printLooopInfo() {
llvm::errs() << "LoopInfo:";
Header->printAsOperand(llvm::errs());
......
......@@ -20,189 +20,130 @@ using namespace llvm;
namespace tas {
void LoopSplitter::addAdapterBasicBlocks(Loop * L, Instruction * SP, Value * Idx) {
auto TopHalf = SP->getParent();
auto BottomHalf = TopHalf->splitBasicBlock(SP);
auto CollectBB = BasicBlock::Create(F->getContext(), "collector", F, BottomHalf);
auto DistBB = BasicBlock::Create(F->getContext(), "distributor", F, BottomHalf);
BranchInst::Create(DistBB, CollectBB);
setSuccessor(TopHalf, CollectBB);
setSuccessor(CollectBB, DistBB);
LoopSplitEdgeBlocks.push_back(CollectBB);
IRBuilder<> Builder(&F->getEntryBlock().front());
auto BrTgtArray = createArray(F, Builder.getInt32Ty(), 32 /*XXX Max Batch size*/);
Builder.SetInsertPoint(DistBB);
auto IdxVal = Builder.CreateLoad(Idx);
auto IdxVal64 = Builder.CreateSExtOrBitCast(IdxVal, Builder.getInt64Ty());
auto BrValPtr = Builder.CreateGEP(BrTgtArray, {Builder.getInt64(0), IdxVal64});
auto BrVal = Builder.CreateLoad(BrValPtr);
SwitchI = Builder.CreateSwitch(BrVal, BottomHalf);
writeToAsmFile(*F->getParent());
SmallVector<BasicBlock *, 4> DivergeBlocks;
auto DivergeBlock = TopHalf->getUniquePredecessor();
if (DivergeBlock != L->getHeader()) {
DivergeBlocks.push_back(DivergeBlock);
}
for (auto & DivergeBB : DivergeBlocks) {
auto TermI = cast<BranchInst>(DivergeBB->getTerminator());
auto Cond = TermI->getCondition();
Builder.SetInsertPoint(TermI);
auto FalseBB = TermI->getSuccessor(1);
auto TgtBBVal = Builder.CreateSelect(Cond,
BBToId[TermI->getSuccessor(0)],
BBToId[FalseBB]);
void LoopSplitter::addBatchArrayForIntermediateVars(Loop * L0) {
auto Idx = getLoopIndexVar(L0);
IRBuilder<> Builder (F->getContext());
for (auto & AV : AnnotatedVars) {
auto arrayPtr = createArray(F, cast<AllocaInst>(AV)->getAllocatedType(), 32);
auto NumUses = AV->getNumUses();
while (NumUses > 0) {
User * U = AV->user_back();
if (cast<Instruction>(U)->getParent() == &F->getEntryBlock()) {
NumUses--;
continue;
}
auto IdxVal = Builder.CreateLoad(Idx);
auto IdxVal64 = Builder.CreateSExtOrBitCast(IdxVal, Builder.getInt64Ty());
auto BrValPtr = Builder.CreateGEP(BrTgtArray, {Builder.getInt64(0), IdxVal64});
Builder.CreateStore(TgtBBVal, BrValPtr);
TermI->setSuccessor(1, CollectBB);
SwitchI->addCase(BBToId[FalseBB], FalseBB);
Builder.SetInsertPoint(cast<Instruction>(U));
auto ptr = Builder.CreateGEP(arrayPtr, {Builder.getInt64(0), Builder.CreateLoad(Idx)});
U->replaceUsesOfWith(AV, ptr);
NumUses--;
}
}
}
bool LoopSplitter::prepareForLoopSplit(Function *F, Loop * L0, Stats & stat) {
auto Idx = getLoopIndexVar(L0);
auto AnnotatedVars = detectExpPtrVars(F);
auto VarUsePoints = detectExpPtrUses(AnnotatedVars);
// Add unique id to each basic block
unsigned i = 0;
IRBuilder<> Builder(F->getContext());
auto SetIntValForBB = [&] (const auto & BB) {
BBToId.insert(std::make_pair(&BB, Builder.getInt32(++i)));
};
for_each(*F, SetIntValForBB);
for_each(VarUsePoints,
[&] (auto & VarUse) { insertLLVMPrefetchIntrinsic(F, VarUse); });
for_each(VarUsePoints,
[&] (auto & VarUse) { addAdapterBasicBlocks(L0, VarUse, Idx); });
bool LoopSplitter::prepareForLoopSplit(Function *F, Stats & stat) {
AnnotatedVars = detectExpPtrVars(F);
VarUsePoints = detectExpPtrUses(AnnotatedVars);
for (int i = 0; i < VarUsePoints.size(); ++i) {
auto & VarUse = VarUsePoints[i];
// VarUse is not first instruction
auto FirstI = VarUse->getParent()->begin();
assert (&*FirstI != VarUse && "VarUse is not first instruction in basic block");
auto PrevI = VarUse->getPrevNode();
Instruction * SplitPoint = nullptr;
if (isa<GetElementPtrInst>(PrevI)) {
SplitPoint = PrevI;
}
auto Prefetch = insertLLVMPrefetchIntrinsic(F, VarUse, SplitPoint);
EndBlocks.push_back(Prefetch->getParent());
Prefetch->getParent()->splitBasicBlock(Prefetch->getNextNode());
SplitPoints.push_back(SplitPoint ? SplitPoint : VarUse);
}
stat.AnnotatedVarsSize = AnnotatedVars.size();
stat.VarUsePointsSize = VarUsePoints.size();
return VarUsePoints.size() != 0;
}
void LoopSplitter::doLoopSplit(Function * F, Loop * L0, BasicBlock * SplitBlock) {
auto OldHeader = L0->getHeader();
auto OldEntry = cast<BranchInst>(OldHeader->getTerminator())->getSuccessor(0);
auto PreLoopBB = getPreLoopBlock(L0);
auto PostLoopBB = cast<BranchInst>(OldHeader->getTerminator())->getSuccessor(1);
auto MidBlock = SplitBlock->getUniqueSuccessor();
bool LoopSplitter::run() {
bool changed = prepareForLoopSplit(F, stat);
if (!changed) return false;
auto LBT = LoopBodyTraverser(L0);
DominatorTree DT(*F);
LoopInfo LI(DT);
// Collect Blocks in range [OldEntry, MidBlock)
vector<BasicBlock *> TopLoopBlocks;
LBT.traverse(TopLoopBlocks, OldEntry, MidBlock);
// If no loops, we are done.
if (LI.begin() == LI.end()) return false;
// Collect Blocks in range [MidBlock, Latch)
vector<BasicBlock *> BottomLoopBlocks;
LBT.traverse(BottomLoopBlocks, MidBlock, L0->getLoopLatch());
// XXX Assume only one loop for now.
auto L0 = *LI.begin();
auto BottomLoop = IRLoop();
BottomLoop.extractLoopSkeleton(L0);
addBatchArrayForIntermediateVars(L0);
auto TopLoop = IRLoop();
TopLoop.constructEmptyLoop(getLoopTripCount(L0), PreLoopBB ? PreLoopBB : BottomLoop.getHeader());
auto BodyBegin = L0->getHeader()->getTerminator()->getSuccessor(0);
BodyEnd = L0->getLoopLatch()->getSinglePredecessor();
EndBlocks.push_back(BodyEnd);
TopLoop.setLoopBlocks(TopLoopBlocks);
BottomLoop.setLoopBlocks(BottomLoopBlocks);
auto EntryBlock = getPreLoopBlock(L0);
EntryBlock->printAsOperand(errs()); errs() << " - Entry\n";
ExitBlock = L0->getExitBlock();
auto TripCount = getLoopTripCount(L0);
assert (ExitBlock && "Loop must have a single exit block!");
setSuccessor(PreLoopBB, TopLoop.getPreHeader());
// TODO Setting index value to 0 when entering new loop.
BasicBlock * BLoopEntry = BottomLoop.getPreHeader() == &F->getEntryBlock() ?
BottomLoop.getHeader() : BottomLoop.getPreHeader();
setSuccessor(TopLoop.getHeader(), BLoopEntry, 1);
setSuccessor(BottomLoop.getHeader(), PostLoopBB, 1);
}
auto ParentLoop = IRLoop(F->getContext());
ParentLoop.extractLoopSkeleton(L0);
/*
// If FalseBB is terminating instruction, use latch block as target instead.
SmallVector<BasicBlock *, 4> Returns;
getReturnBlocks(F, Returns);
for (auto & Case : SwitchI->cases()) {
if (find(Returns, Case.getCaseSuccessor()) != Returns.end()) {
Case.setSuccessor(L0->getLoopLatch());
}
}
*/
/*
DenseSet<BasicBlock *> Blocks;
auto TruePathBB = NewHeader->getTerminator()->getSuccessor(0);
visitSuccessor(Blocks, TruePathBB, NewLatch);
auto BB = Blocks.begin();
auto BE = Blocks.end();
auto IB = (*BB)->begin();
auto IE = (*BB)->end();
while (BB != BE) {
while (IB != IE) {
auto UB = (*IB).user_begin();
auto UE = (*IB).user_end();
while (UB != UE) {
auto * U = *UB++;
if (auto * Inst = dyn_cast<Instruction>(U)) {
if (Blocks.find(Inst->getParent()) == Blocks.end()) {
auto arrayPtr = createArray(F, (*IB).getType(), 32);
Builder.SetInsertPoint((*IB).getNextNode());
IndexVarVal = Builder.CreateLoad(IndexVar);
auto ptr = Builder.CreateGEP(arrayPtr, {Builder.getInt64(0), IndexVarVal});
auto str = Builder.CreateStore(&*IB, ptr);
errs() << *ptr << "\n" << *str << "\n";
}
}
}
++IB;
}
++BB;
std::vector<IRLoop> Loops;
for (int i = 0; i < AnnotatedVars.size(); ++i) {
Loops.push_back(IRLoop(F->getContext()));
Loops.back().constructEmptyLoop(TripCount, F);
}
}
*/
bool LoopSplitter::run() {
// If no loops, we are done.
if (LI->begin() == LI->end()) return false;
// XXX Assume only one loop for now.
auto L0 = *LI->begin();
Loops.push_back(ParentLoop);
ExitBlock = L0->getExitBlock();
assert (ExitBlock && "Loop must have a single exit block!");
// Stitch all loop skeletons.
setSuccessor(&F->getEntryBlock(), Loops.front().getPreHeader());
for (int i = 0; i < Loops.size()-1; ++i) {
Loops[i].setExitBlock(Loops[i+1].getPreHeader());
}
Loops.back().setExitBlock(ExitBlock);
auto ParentLoop = IRLoop();
ParentLoop.analyze(L0);
std::vector<BasicBlock *> Blocks;
traverseLoopBody(Blocks, BodyBegin);
bool changed = prepareForLoopSplit(F, L0, stat);
if (!changed) return false;
// FIXME One extra set in the end
LoopBlocks.pop_back();
DominatorTree DT (*F);
LoopInfo NewLI (DT);
L0 = *NewLI.begin();
if (NewLI.begin() == NewLI.end()) {
errs() << "Loop info lost, something wrong with preparation\n";
return false;
/*
for (auto & Blocks : LoopBlocks) {
for (auto & BB : Blocks) {
BB->printAsOperand(errs());
errs() << " ";
}
errs() << "\n";
}
auto & SplitBB = LoopSplitEdgeBlocks.front();
*/
doLoopSplit(F, L0, SplitBB);
assert(stat.AnnotatedVarsSize == 1 && "Atleast one annotation must be there!");
for (int i = 0; i < Loops.size(); ++i) {
Loops[i].setLoopBlocks(LoopBlocks[i]);
}
F->print(errs());
return true;
}
// Walks the CFG depth-first from Start, partitioning the visited blocks into
// groups that are appended to LoopBlocks.
//
// Blocks is the group currently being accumulated. Start is always appended
// to it first. When Start is one of EndBlocks, the accumulated group is pushed
// onto LoopBlocks and Blocks is emptied so the next group starts fresh; if
// Start is additionally BodyEnd, the walk stops on this path. Otherwise every
// successor of Start is visited recursively with the same accumulator.
//
// NOTE(review): there is no visited set, so a block reachable along several
// paths is appended once per path — confirm callers only pass CFG shapes for
// which that is acceptable.
void LoopSplitter::traverseLoopBody(std::vector<BasicBlock *> & Blocks,
                                    BasicBlock * Start) {
  Blocks.push_back(Start);

  // Is Start one of the blocks that closes a group?
  bool ClosesGroup = false;
  for (auto * EndBB : EndBlocks) {
    if (EndBB == Start) {
      ClosesGroup = true;
      break;
    }
  }

  if (ClosesGroup) {
    // Snapshot the finished group, then reset the accumulator.
    LoopBlocks.push_back(Blocks);
    Blocks.clear();
    if (Start == BodyEnd) {
      return;
    }
  }

  for (auto * Succ : successors(Start)) {
    traverseLoopBody(Blocks, Succ);
  }
}
}
......@@ -32,16 +32,26 @@ struct Stats {
};
class LoopSplitter {
using ListOfBlocksType = std::vector<std::vector<llvm::BasicBlock *>>;
llvm::Function * F;
llvm::LoopInfo * LI;
Stats stat;
llvm::SwitchInst * SwitchI;
llvm::SmallVector<llvm::Value *, 4> AnnotatedVars;
llvm::SmallVector<llvm::LoadInst *, 4> VarUsePoints;
llvm::SmallVector<llvm::Instruction *, 4> SplitPoints;
llvm::SmallVector<llvm::BasicBlock *, 4> EndBlocks;
llvm::BasicBlock * BodyEnd;
llvm::SmallVector<std::pair<llvm::BasicBlock *, llvm::BasicBlock *>, 4> LoopBodyRange;
llvm::BasicBlock * ExitBlock;
llvm::DenseMap<const llvm::BasicBlock *, llvm::ConstantInt *> BBToId;
llvm::SmallVector<llvm::BasicBlock *, 4> LoopSplitEdgeBlocks;
bool prepareForLoopSplit(llvm::Function * F, llvm::Loop * L0, Stats & stat);
ListOfBlocksType LoopBlocks;
bool prepareForLoopSplit(llvm::Function * F, Stats & stat);
void fixValueDependenceBetWeenLoops();
void addBatchArrayForIntermediateVars(llvm::Loop * L0);
void traverseLoopBody(std::vector<llvm::BasicBlock *> & LoopBlocks, llvm::BasicBlock * Start);
public:
LoopSplitter(llvm::Function * F_, llvm::LoopInfo * LI_)
: F(F_), LI(LI_) {}
......@@ -49,7 +59,7 @@ public:
bool run();
Stats & getStats() { return stat; }
void addAdapterBasicBlocks(llvm::Loop * L0, llvm::Instruction * SP, llvm::Value * Idx);
void doLoopSplit(llvm::Function * F, llvm::Loop * L0, llvm::BasicBlock * SplitBlock);
//void doLoopSplit(llvm::Function * F, llvm::Loop * L0, llvm::Instruction * SplitPoint);
};
} // tas namespace
......
......@@ -211,21 +211,22 @@ void cloneLoopBasicBlocks(Function * F, Loop * L, ValueToValueMapTy & VMap) {
LoopTerminator->setSuccessor(1, ClonedBlocks.front());
}
void insertLLVMPrefetchIntrinsic(Function * F, Instruction * PtrAllocaUse) {
IRBuilder<> Builder(PtrAllocaUse);
CallInst * insertLLVMPrefetchIntrinsic(Function * F, Instruction * PtrAllocaUse, Instruction * InsertBefore = nullptr) {
IRBuilder<> Builder(InsertBefore?InsertBefore:PtrAllocaUse);
auto Ptr = Builder.CreateLoad(PtrAllocaUse->getOperand(0), "prefetch_load");
auto CastI = Builder.CreateBitCast(Ptr, Builder.getInt8PtrTy(), "prefetch1");
// Add llvm prefetch intrinsic call.
Type *I32 = Type::getInt32Ty(F->getContext());
Value *PrefetchFunc = Intrinsic::getDeclaration(F->getParent(), Intrinsic::prefetch);
Builder.CreateCall(
auto Prefetch = Builder.CreateCall(
PrefetchFunc,
{CastI, // Pointer Value
ConstantInt::get(I32, 0), // read (0) or write (1)
ConstantInt::get(I32, 3), // no_locality (0) to extreme temporal locality (3)
ConstantInt::get(I32, 1)} // data (1) or instruction (0)
);
return Prefetch;
}
/// Replace old value with new value within basic block.
......
......@@ -35,7 +35,7 @@ void setAnnotationInFunctionObject(llvm::Module * M);
void cloneLoopBasicBlocks(llvm::Function * F, llvm::Loop * L,
llvm::ValueToValueMapTy & VMap);
void insertLLVMPrefetchIntrinsic(llvm::Function * F, llvm::Instruction * I);
llvm::CallInst * insertLLVMPrefetchIntrinsic(llvm::Function * F, llvm::Instruction * I, llvm::Instruction * InsertBefore);
void replaceUsesWithinBB(llvm::Value * From, llvm::Value * To,
llvm::BasicBlock * BB);
......
......@@ -99,8 +99,6 @@ int main(int argc, char * argv[]) {
}
if (FnStr.second.compare("tas_batch_maker") == 0) {
tas::BlockPredication BP (FnStr.first);
auto res = BP.run();
// Make Batch version
tas::BatchMaker BM(FnStr.first);
auto BatchFunc = BM.run();
......@@ -108,7 +106,7 @@ int main(int argc, char * argv[]) {
DominatorTree DT(*BatchFunc);
LoopInfo LI(DT);
tas::LoopSplitter LS(BatchFunc, &LI);
res = LS.run();
auto res = LS.run();
}
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment