Commit 4b09d5fa authored by guruhegde's avatar guruhegde

LoopSplitter - fix issues

parent 280aaa26
......@@ -45,6 +45,7 @@ void BatchMaker::createBatchedFormFnPrototype(vector<TASArgAttr> & BatchFuncArgL
}
// Batch size parameter
BatchSizeArgPos = i;
BatchFuncArgList.emplace_back(
TASArgAttr { false, i++, Type::getInt32Ty(Ctx), nullptr, BatchSizeVarName});
......@@ -85,6 +86,10 @@ void BatchMaker::replaceOldArgUsesWithBatchArgs(vector<TASArgAttr> & BatchFuncAr
[&] (TASArgAttr & Attr) {
if (Attr.IsBatch) BatchArgs.push_back(Attr.Val); });
Builder.SetInsertPoint(&EntryBB->front());
BatchSizeAlloca = Builder.CreateAlloca(BatchFuncArgList[BatchSizeArgPos].Ty);
Builder.CreateStore(BatchFuncArgList[BatchSizeArgPos].Val, BatchSizeAlloca);
for (auto & BatchArg : BatchArgs) {
Builder.SetInsertPoint(&EntryBB->front());
auto BatchArgAlloca = Builder.CreateAlloca(BatchArg->getType());
......@@ -179,10 +184,8 @@ void BatchMaker::addBatchLoop(BasicBlock * RetBlock, AllocaInst * IdxPtr) {
[&] (BasicBlock * BB) { setSuccessor(BB, UniqueExitingBlock); });
}
auto BatchSizeVal = BatchFunc->getValueSymbolTable()->lookup(BatchSizeVarName);
auto TL0 = TASForLoop(BatchFunc->getContext(), ParentBB, RetBlock,
std::string("loop0"), BatchFunc, BatchSizeVal, IdxPtr);
std::string("loop0"), BatchFunc, BatchSizeAlloca, IdxPtr);
TL0.setLoopBody(BatchBB, UniqueExitingBlock);
}
......
......@@ -40,6 +40,8 @@ class BatchMaker {
bool IsRetTyVoid = false;
std::string BatchSizeVarName = std::string("TAS_BatchSize");
std::string ReturnVarName = std::string("TAS_ReturnVar");
unsigned BatchSizeArgPos;
llvm::AllocaInst * BatchSizeAlloca;
void createBatchedFormFnPrototype(std::vector<TASArgAttr> & BatchFuncArgList);
void addBatchLoop(llvm::BasicBlock * RetBlock, llvm::AllocaInst * IdxPtr);
......
......@@ -10,7 +10,7 @@ namespace tas {
TASForLoop::TASForLoop(LLVMContext & Ctx, BasicBlock * Prev,
BasicBlock * Next, const std::string & Name, Function * F,
llvm::Value * TC, AllocaInst * IP)
llvm::AllocaInst * TC, AllocaInst * IP)
: F(F), Name (std::move(Name)), TripCount(TC), IdxVarPtr(IP)
{
addEmptyLoop(Ctx, Prev, Next);
......@@ -32,8 +32,9 @@ void TASForLoop::addEmptyLoop(LLVMContext & Ctx, BasicBlock * Prev,
addIncrementIndexOp(IdxVarPtr, BI);
Builder.SetInsertPoint(Header);
auto TCVal = Builder.CreateLoad(TripCount);
IndexVar = Builder.CreateLoad(IdxVarPtr);
auto * icmp = Builder.CreateICmpSLT(IndexVar, TripCount, "loop-predicate");
auto * icmp = Builder.CreateICmpSLT(IndexVar, TCVal, "loop-predicate");
// Stitch entry point in control flow.
if (Prev) {
......
......@@ -29,7 +29,7 @@ public:
explicit TASForLoop(llvm::LLVMContext & Ctx,
llvm::BasicBlock * Prev, llvm::BasicBlock * Next,
const std::string & Name, llvm::Function * F = nullptr,
llvm::Value * TC = nullptr, llvm::AllocaInst * IP = nullptr);
llvm::AllocaInst * TC = nullptr, llvm::AllocaInst * IP = nullptr);
void addEmptyLoop(llvm::LLVMContext & Ctx, llvm::BasicBlock * Prev, llvm::BasicBlock * Next);
void setLoopBody(llvm::BasicBlock * BodyBB);
......
......@@ -23,7 +23,7 @@ void IRLoop::analyze(Loop * L) {
}
}
void IRLoop::constructEmptyLoop(Value * TripCount, BasicBlock * ExitBlock) {
void IRLoop::constructEmptyLoop(AllocaInst * TripCount, BasicBlock * ExitBlock) {
auto & Ctx = ExitBlock->getContext();
const auto & F = ExitBlock->getParent();
......@@ -54,25 +54,37 @@ void IRLoop::constructEmptyLoop(Value * TripCount, BasicBlock * ExitBlock) {
// Populate header block
Builder.SetInsertPoint(Header);
auto IdxVal = Builder.CreateLoad(IdxAlloca);
auto * Icmp = Builder.CreateICmpSLT(IdxVal, TripCount, "loop-predicate");
auto TC = Builder.CreateLoad(TripCount);
auto * Icmp = Builder.CreateICmpSLT(IdxVal, TC, "loop-predicate");
Builder.CreateCondBr(Icmp, EmptyBody, ExitBlock);
}
void IRLoop::setLoopBlocks(SmallVectorImpl<BasicBlock *> & Blocks) {
for_each(Blocks, [&] (auto & BB) { Blocks.push_back(BB); });
// Appends every block in BlockList to this loop's Blocks, then links Header
// to the first stored block and the last stored block to Latch.
void IRLoop::setLoopBlocks(std::vector<BasicBlock *> & BlockList) {
  for (auto & BodyBB : BlockList) {
    Blocks.push_back(BodyBB);
  }

  assert (!Blocks.empty() && "Blocks can't be empty");
  auto * FirstBB = Blocks.front();
  setSuccessor(Header, FirstBB); // True Path

  assert (Latch && "Latch is NULL");
  auto * LastBB = Blocks.back();
  setSuccessor(LastBB, Latch);
}
void LoopBodyTraverser::traverse(SmallVectorImpl<BasicBlock *> & Blocks,
// Depth-first walk over successor edges beginning at Start. Each block
// reached is appended to Blocks once. The walk stops at End (End itself is
// not recorded) and never follows an edge into a block listed in ExitBlocks.
void LoopBodyTraverser::traverse(std::vector<BasicBlock *> & Blocks,
                                 BasicBlock * Start, BasicBlock * End) {
  const bool AlreadyVisited =
      std::find(Blocks.begin(), Blocks.end(), Start) != Blocks.end();
  if (AlreadyVisited || Start == End) return;

  Blocks.push_back(Start);
  for (auto * Succ : successors(Start)) {
    // Don't travel out of loop
    const bool IsExitBlock =
        std::find(ExitBlocks.begin(), ExitBlocks.end(), Succ) != ExitBlocks.end();
    if (!IsExitBlock) {
      traverse(Blocks, Succ, End);
    }
  }
}
// Depth-first walk over predecessor edges beginning at Start. Each block
// reached is appended to Blocks once. The walk stops at End (End itself is
// not recorded).
void LoopBodyTraverser::traverseReverse(std::vector<BasicBlock *> & Blocks,
                                        BasicBlock * Start, BasicBlock * End) {
  // Already recorded: nothing more to do for this block.
  if (std::find(Blocks.begin(), Blocks.end(), Start) != Blocks.end()) return;
  if (Start == End) return;
  Blocks.push_back(Start);
  for (auto * Pred : predecessors(Start)) {
    // Recurse backwards. Calling the forward traverse() here would switch to
    // successor edges after a single backward step.
    traverseReverse(Blocks, Pred, End);
  }
}
......
......@@ -23,10 +23,10 @@ public:
IRLoop() = default;
void analyze(llvm::Loop * L);
void constructEmptyLoop(llvm::Value * TripCount,
void constructEmptyLoop(llvm::AllocaInst * TripCount,
llvm::BasicBlock * InsertAfter);
void extractLoopSkeleton(llvm::Loop * L);
void setLoopBlocks(llvm::SmallVectorImpl<llvm::BasicBlock *> & Blocks);
void setLoopBlocks(std::vector<llvm::BasicBlock *> & Blocks);
llvm::BasicBlock * getPreHeader() {
return PreHeader;
......@@ -35,6 +35,16 @@ public:
llvm::BasicBlock * getHeader() {
return Header;
}
// Debug helper: writes Header and Latch as operands, then *IdxAlloca and the
// number of entries in Blocks, to llvm::errs().
void printLooopInfo() {
  auto & Out = llvm::errs();
  Out << "LoopInfo:";
  Header->printAsOperand(Out);
  Out << " ";
  Latch->printAsOperand(Out);
  Out << " " << *IdxAlloca << "\n";
  Out << " No of Blocks = " << Blocks.size() << "\n";
}
};
class LoopBodyTraverser {
......@@ -45,7 +55,10 @@ public:
L->getExitBlocks(ExitBlocks);
}
void traverse(llvm::SmallVectorImpl<llvm::BasicBlock *> & Blocks,
void traverse(std::vector<llvm::BasicBlock *> & Blocks,
llvm::BasicBlock * Start, llvm::BasicBlock * End);
void traverseReverse(std::vector<llvm::BasicBlock *> & Blocks,
llvm::BasicBlock * Start, llvm::BasicBlock * End);
void printExitBlocks() {
......
......@@ -10,8 +10,10 @@
#include <llvm/IR/IRBuilder.h>
#include <iostream>
#include <vector>
#include <string>
using namespace std;
using namespace llvm;
#define DEBUG_TYPE "tas-batch-process"
......@@ -30,8 +32,6 @@ void LoopSplitter::addAdapterBasicBlocks(Loop * L, Instruction * SP, Value * Idx
setSuccessor(CollectBB, DistBB);
LoopSplitEdgeBlocks.push_back(CollectBB);
// General case: Find all edges from basic blocks above tophalf
// and connect it to CollectBB, use switch Inst
IRBuilder<> Builder(&F->getEntryBlock().front());
auto BrTgtArray = createArray(F, Builder.getInt32Ty(), 32 /*XXX Max Batch size*/);
......@@ -42,23 +42,12 @@ void LoopSplitter::addAdapterBasicBlocks(Loop * L, Instruction * SP, Value * Idx
auto BrVal = Builder.CreateLoad(BrValPtr);
SwitchI = Builder.CreateSwitch(BrVal, BottomHalf);
// Handle diverge blocks
auto OldHeader = L->getHeader();
auto OldEntry = cast<BranchInst>(OldHeader->getTerminator())->getSuccessor(0);
SmallVector<BasicBlock *, 4> TopBlocks;
SmallVector<BasicBlock *, 4> BottomBlocks;
LoopBodyTraverser LBT (L);
LBT.traverse(TopBlocks, OldEntry, DistBB);
LBT.traverse(BottomBlocks, DistBB, L->getLoopLatch());
writeToAsmFile(*F->getParent());
SmallVector<BasicBlock *, 4> DivergeBlocks;
for (auto & BB : TopBlocks) {
for (auto * Succ : successors(BB)) {
Succ->printAsOperand(errs());
if (std::find(BottomBlocks.begin(), BottomBlocks.end(), Succ) != BottomBlocks.end())
DivergeBlocks.push_back(BB);
}
auto DivergeBlock = TopHalf->getUniquePredecessor();
if (DivergeBlock != L->getHeader()) {
DivergeBlocks.push_back(DivergeBlock);
}
for (auto & DivergeBB : DivergeBlocks) {
......@@ -115,24 +104,27 @@ void LoopSplitter::doLoopSplit(Function * F, Loop * L0, BasicBlock * SplitBlock)
auto LBT = LoopBodyTraverser(L0);
// Collect Blocks in range [OldEntry, MidBlock)
SmallVector<BasicBlock *, 4> TopLoopBlocks;
vector<BasicBlock *> TopLoopBlocks;
LBT.traverse(TopLoopBlocks, OldEntry, MidBlock);
// Collect Blocks in range [MidBlock, Latch)
SmallVector<BasicBlock *, 4> BottomLoopBlocks;
vector<BasicBlock *> BottomLoopBlocks;
LBT.traverse(BottomLoopBlocks, MidBlock, L0->getLoopLatch());
auto BottomLoop = IRLoop();
BottomLoop.extractLoopSkeleton(L0);
auto TopLoop = IRLoop();
TopLoop.constructEmptyLoop(getLoopTripCount(L0), BottomLoop.getHeader());
TopLoop.constructEmptyLoop(getLoopTripCount(L0), PreLoopBB ? PreLoopBB : BottomLoop.getHeader());
TopLoop.setLoopBlocks(TopLoopBlocks);
BottomLoop.setLoopBlocks(BottomLoopBlocks);
setSuccessor(PreLoopBB, TopLoop.getPreHeader());
setSuccessor(TopLoop.getHeader(), BottomLoop.getHeader(), 1);
// TODO Setting index value to 0 when entering new loop.
BasicBlock * BLoopEntry = BottomLoop.getPreHeader() == &F->getEntryBlock() ?
BottomLoop.getHeader() : BottomLoop.getPreHeader();
setSuccessor(TopLoop.getHeader(), BLoopEntry, 1);
setSuccessor(BottomLoop.getHeader(), PostLoopBB, 1);
/*
......@@ -197,8 +189,6 @@ bool LoopSplitter::run() {
bool changed = prepareForLoopSplit(F, L0, stat);
if (!changed) return false;
// F->print(errs());
DominatorTree DT (*F);
LoopInfo NewLI (DT);
L0 = *NewLI.begin();
......@@ -207,7 +197,10 @@ bool LoopSplitter::run() {
return false;
}
auto & SplitBB = LoopSplitEdgeBlocks.front();
// doLoopSplit(F, L0, SplitBB);
doLoopSplit(F, L0, SplitBB);
assert(stat.AnnotatedVarsSize == 1 && "Atleast one annotation must be there!");
return true;
}
......
......@@ -115,7 +115,6 @@ SmallVector<LoadInst *, 4> detectExpPtrUses(SmallVectorImpl<Value *> & Annotated
for_each(AnnotatedVars, [&]
(const auto & Var) {
auto FU = findEarliestPointerDerefInstruction(Var);
errs() << *FU << "\n";
if (!FU) return;
VarUsePoints.push_back(const_cast<LoadInst *>(FU));
});
......@@ -480,10 +479,11 @@ void visitSuccessor(SmallVectorImpl<BasicBlock *> & Blocks, BasicBlock * StartBl
}
}
Value * getLoopTripCount(Loop * L0) {
AllocaInst * getLoopTripCount(Loop * L0) {
auto Header = L0->getHeader();
auto Cond = cast<BranchInst>(Header->getTerminator())->getCondition();
return cast<ICmpInst>(Cond)->getOperand(1);
auto TCVal = cast<ICmpInst>(Cond)->getOperand(1);
return cast<AllocaInst>(cast<LoadInst>(TCVal)->getOperand(0));
}
BasicBlock * getPreLoopBlock(Loop * L) {
......
......@@ -112,7 +112,7 @@ void visitSuccessor(llvm::SmallVectorImpl<llvm::BasicBlock *> & Blocks,
llvm::BasicBlock * CurBlock,
llvm::BasicBlock * EndBlock);
llvm::Value * getLoopTripCount(llvm::Loop * L0);
llvm::AllocaInst * getLoopTripCount(llvm::Loop * L0);
llvm::BasicBlock * getPreLoopBlock(llvm::Loop * L);
......
......@@ -167,7 +167,6 @@ TEST_CASE("fast_flows_packet") {
REQUIRE(M != nullptr);
auto F = M->getFunction("fast_flows_packet");
errs() << "Running block predication\n";
BlockPredication BP(F);
BP.run();
......
......@@ -61,9 +61,8 @@ TEST_CASE("fn with single loop") {
LoopSplitter LS(F, &LI);
LS.run();
auto Stats = LS.getStats();
REQUIRE(Stats.AnnotatedVarsSize == 1);
REQUIRE(Stats.VarUsePointsSize == 1);
REQUIRE(LS.getStats().AnnotatedVarsSize == 1);
REQUIRE(LS.getStats().VarUsePointsSize == 1);
//F->print(errs());
auto asmFile = writeToAsmFile(*M);
......@@ -76,3 +75,24 @@ TEST_CASE("fn with single loop") {
auto ret = system(binary.c_str());
REQUIRE(ret == 0);
}
// Parses the IR generated from fast_flows.c, runs LoopSplitter on
// fast_flows_packet, then writes the module out and builds an object from it.
TEST_CASE("fast_flows_packet loop split") {
  std::string filePrefix = "fast_flows";
  auto M = parseIR(generateIR(filePrefix + string(".c"), input_dir, true), input_dir);
  REQUIRE(M != nullptr);
  M->setSourceFileName(filePrefix + string("_batch.ll"));
  {
    std::string functionName = "fast_flows_packet";
    auto F = M->getFunction(functionName);
    // getFunction() yields null when the symbol is absent; fail the test here
    // instead of dereferencing a null pointer in the DominatorTree below.
    REQUIRE(F != nullptr);
    DominatorTree DT(*F);
    LoopInfo LI(DT);
    LoopSplitter LS(F, &LI);
    LS.run();
    //F->print(errs());
    auto asmFile = writeToAsmFile(*M);
    auto TestObject = generateObject(asmFile);
  }
}
......@@ -77,7 +77,6 @@ int main(int argc, char * argv[]) {
for (auto & FnStr : FnLists) {
if (FnStr.second.compare("tas_block_predicate") != 0) continue;
// Block Predication
errs() << "Running block predication on " << FnStr.second << "\n";
tas::BlockPredication BP(FnStr.first);
auto res = BP.run();
if (!res) {
......@@ -90,6 +89,7 @@ int main(int argc, char * argv[]) {
tas::BatchMaker BM(FnStr.first);
auto BatchFunc = BM.run();
writeToAsmFile(*M);
// Loop Splitting
DominatorTree DT(*BatchFunc);
LoopInfo LI(DT);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment