diff --git a/src/BatchMaker.cpp b/src/BatchMaker.cpp index f31ce0f91dfb2367417b32ea065c84ab027fdd34..cc596bd35b8ac2b820d31677f2892d1cf5c9a624 100644 --- a/src/BatchMaker.cpp +++ b/src/BatchMaker.cpp @@ -45,6 +45,7 @@ void BatchMaker::createBatchedFormFnPrototype(vector & BatchFuncArgL } // Batch size parameter + BatchSizeArgPos = i; BatchFuncArgList.emplace_back( TASArgAttr { false, i++, Type::getInt32Ty(Ctx), nullptr, BatchSizeVarName}); @@ -85,6 +86,10 @@ void BatchMaker::replaceOldArgUsesWithBatchArgs(vector & BatchFuncAr [&] (TASArgAttr & Attr) { if (Attr.IsBatch) BatchArgs.push_back(Attr.Val); }); + Builder.SetInsertPoint(&EntryBB->front()); + BatchSizeAlloca = Builder.CreateAlloca(BatchFuncArgList[BatchSizeArgPos].Ty); + Builder.CreateStore(BatchFuncArgList[BatchSizeArgPos].Val, BatchSizeAlloca); + for (auto & BatchArg : BatchArgs) { Builder.SetInsertPoint(&EntryBB->front()); auto BatchArgAlloca = Builder.CreateAlloca(BatchArg->getType()); @@ -179,10 +184,8 @@ void BatchMaker::addBatchLoop(BasicBlock * RetBlock, AllocaInst * IdxPtr) { [&] (BasicBlock * BB) { setSuccessor(BB, UniqueExitingBlock); }); } - auto BatchSizeVal = BatchFunc->getValueSymbolTable()->lookup(BatchSizeVarName); - auto TL0 = TASForLoop(BatchFunc->getContext(), ParentBB, RetBlock, - std::string("loop0"), BatchFunc, BatchSizeVal, IdxPtr); + std::string("loop0"), BatchFunc, BatchSizeAlloca, IdxPtr); TL0.setLoopBody(BatchBB, UniqueExitingBlock); } diff --git a/src/BatchMaker.h b/src/BatchMaker.h index 021e0581cafb66c79b00a16fbda8fc43ed591590..ba868fbed323aeef7801c0fd7376c88bcde02fb8 100644 --- a/src/BatchMaker.h +++ b/src/BatchMaker.h @@ -40,6 +40,8 @@ class BatchMaker { bool IsRetTyVoid = false; std::string BatchSizeVarName = std::string("TAS_BatchSize"); std::string ReturnVarName = std::string("TAS_ReturnVar"); + unsigned BatchSizeArgPos; + llvm::AllocaInst * BatchSizeAlloca; void createBatchedFormFnPrototype(std::vector & BatchFuncArgList); void addBatchLoop(llvm::BasicBlock * RetBlock, llvm::AllocaInst * IdxPtr); diff --git a/src/ForLoop.cpp b/src/ForLoop.cpp index 7d1a2c44883bd00f7c32b10bf91c302346c6516d..3035eb82368887ae4be5122390f46873018d1650 100644 --- a/src/ForLoop.cpp +++ b/src/ForLoop.cpp @@ -10,7 +10,7 @@ namespace tas { TASForLoop::TASForLoop(LLVMContext & Ctx, BasicBlock * Prev, BasicBlock * Next, const std::string & Name, Function * F, - llvm::Value * TC, AllocaInst * IP) + llvm::AllocaInst * TC, AllocaInst * IP) : F(F), Name (std::move(Name)), TripCount(TC), IdxVarPtr(IP) { addEmptyLoop(Ctx, Prev, Next); @@ -32,8 +32,9 @@ void TASForLoop::addEmptyLoop(LLVMContext & Ctx, BasicBlock * Prev, addIncrementIndexOp(IdxVarPtr, BI); Builder.SetInsertPoint(Header); + auto TCVal = Builder.CreateLoad(TripCount); IndexVar = Builder.CreateLoad(IdxVarPtr); - auto * icmp = Builder.CreateICmpSLT(IndexVar, TripCount, "loop-predicate"); + auto * icmp = Builder.CreateICmpSLT(IndexVar, TCVal, "loop-predicate"); // Stitch entry point in control flow. if (Prev) { diff --git a/src/ForLoop.h b/src/ForLoop.h index 5bc763824bab2cc15b64f95c7ab615304df118c9..009811fe97eaf69956015135fa1521580426a412 100644 --- a/src/ForLoop.h +++ b/src/ForLoop.h @@ -29,7 +29,7 @@ public: explicit TASForLoop(llvm::LLVMContext & Ctx, llvm::BasicBlock * Prev, llvm::BasicBlock * Next, const std::string & Name, llvm::Function * F = nullptr, - llvm::Value * TC = nullptr, llvm::AllocaInst * IP = nullptr); + llvm::AllocaInst * TC = nullptr, llvm::AllocaInst * IP = nullptr); void addEmptyLoop(llvm::LLVMContext & Ctx, llvm::BasicBlock * Prev, llvm::BasicBlock * Next); void setLoopBody(llvm::BasicBlock * BodyBB); diff --git a/src/ForLoopV2.cpp b/src/ForLoopV2.cpp index 9853e31005bb167d6f834b29b4406d24323cf739..1205321dad8bbcf20374d274bb17590919f87853 100644 --- a/src/ForLoopV2.cpp +++ b/src/ForLoopV2.cpp @@ -23,7 +23,7 @@ void IRLoop::analyze(Loop * L) { } } -void IRLoop::constructEmptyLoop(Value * TripCount, BasicBlock * ExitBlock) { +void IRLoop::constructEmptyLoop(AllocaInst * TripCount, BasicBlock * ExitBlock) { auto & Ctx = ExitBlock->getContext(); const auto & F = ExitBlock->getParent(); @@ -54,25 +54,37 @@ void IRLoop::constructEmptyLoop(Value * TripCount, BasicBlock * ExitBlock) { // Populate header block Builder.SetInsertPoint(Header); auto IdxVal = Builder.CreateLoad(IdxAlloca); - auto * Icmp = Builder.CreateICmpSLT(IdxVal, TripCount, "loop-predicate"); + auto TC = Builder.CreateLoad(TripCount); + auto * Icmp = Builder.CreateICmpSLT(IdxVal, TC, "loop-predicate"); Builder.CreateCondBr(Icmp, EmptyBody, ExitBlock); } -void IRLoop::setLoopBlocks(SmallVectorImpl & Blocks) { - for_each(Blocks, [&] (auto & BB) { Blocks.push_back(BB); }); +void IRLoop::setLoopBlocks(std::vector & BlockList) { + for_each(BlockList, [&] (auto & BB) { Blocks.push_back(BB); }); assert (!Blocks.empty() && "Blocks can't be empty"); setSuccessor(Header, Blocks.front()); // True Path assert (Latch && "Latch is NULL"); setSuccessor(Blocks.back(), Latch); } -void LoopBodyTraverser::traverse(SmallVectorImpl & Blocks, +void LoopBodyTraverser::traverse(std::vector & Blocks, BasicBlock * Start, BasicBlock * End) { + if (std::find(Blocks.begin(), Blocks.end(), Start) != Blocks.end()) return; if (Start == End) return; Blocks.push_back(Start); for (auto * BB : successors(Start)){ // Don't travel out of loop - //if (std::find(ExitBlocks.begin(), ExitBlocks.end(), BB) != Blocks.end()) continue; + if (std::find(ExitBlocks.begin(), ExitBlocks.end(), BB) != ExitBlocks.end()) continue; + traverse(Blocks, BB, End); + } +} + +void LoopBodyTraverser::traverseReverse(std::vector & Blocks, + BasicBlock * Start, BasicBlock * End) { + if (std::find(Blocks.begin(), Blocks.end(), Start) != Blocks.end()) return; + if (Start == End) return; + Blocks.push_back(Start); + for (auto * BB : predecessors(Start)){ traverse(Blocks, BB, End); } } diff --git a/src/ForLoopV2.h b/src/ForLoopV2.h index 7e77f81aa4a72778a1b9f273121fb914e9fbcf2e..300be101c335d08212ccc3261a13947155bfd9b4 100644 --- a/src/ForLoopV2.h +++ b/src/ForLoopV2.h @@ -23,10 +23,10 @@ public: IRLoop() = default; void analyze(llvm::Loop * L); - void constructEmptyLoop(llvm::Value * TripCount, + void constructEmptyLoop(llvm::AllocaInst * TripCount, llvm::BasicBlock * InsertAfter); void extractLoopSkeleton(llvm::Loop * L); - void setLoopBlocks(llvm::SmallVectorImpl & Blocks); + void setLoopBlocks(std::vector & Blocks); llvm::BasicBlock * getPreHeader() { return PreHeader; @@ -35,6 +35,16 @@ public: llvm::BasicBlock * getHeader() { return Header; } + + void printLooopInfo() { + llvm::errs() << "LoopInfo:"; + Header->printAsOperand(llvm::errs()); + llvm::errs() << " "; + Latch->printAsOperand(llvm::errs()); + llvm::errs() << " "; + llvm::errs() << *IdxAlloca << "\n"; + llvm::errs() << " No of Blocks = " << Blocks.size() << "\n"; + } }; class LoopBodyTraverser { @@ -45,7 +55,10 @@ public: L->getExitBlocks(ExitBlocks); } - void traverse(llvm::SmallVectorImpl & Blocks, + void traverse(std::vector & Blocks, + llvm::BasicBlock * Start, llvm::BasicBlock * End); + + void traverseReverse(std::vector & Blocks, llvm::BasicBlock * Start, llvm::BasicBlock * End); void printExitBlocks() { diff --git a/src/LoopSplitter.cpp b/src/LoopSplitter.cpp index 28880f059d89008560d1e43edb942aa0f46e3611..0e312730e0d964e26a4fa67c338f8988533c220b 100644 --- a/src/LoopSplitter.cpp +++ b/src/LoopSplitter.cpp @@ -10,8 +10,10 @@ #include #include +#include #include +using namespace std; using namespace llvm; #define DEBUG_TYPE "tas-batch-process" @@ -30,8 +32,6 @@ void LoopSplitter::addAdapterBasicBlocks(Loop * L, Instruction * SP, Value * Idx setSuccessor(CollectBB, DistBB); LoopSplitEdgeBlocks.push_back(CollectBB); - // General case: Find all edges from basic blocks above tophalf - // and connect it to CollectBB, use switch Inst IRBuilder<> Builder(&F->getEntryBlock().front()); auto BrTgtArray = createArray(F, Builder.getInt32Ty(), 32 /*XXX Max Batch size*/); @@ -42,23 +42,12 @@ void LoopSplitter::addAdapterBasicBlocks(Loop * L, Instruction * SP, Value * Idx auto BrVal = Builder.CreateLoad(BrValPtr); SwitchI = Builder.CreateSwitch(BrVal, BottomHalf); - // Handle diverge blocks - auto OldHeader = L->getHeader(); - auto OldEntry = cast(OldHeader->getTerminator())->getSuccessor(0); - - SmallVector TopBlocks; - SmallVector BottomBlocks; - LoopBodyTraverser LBT (L); - LBT.traverse(TopBlocks, OldEntry, DistBB); - LBT.traverse(BottomBlocks, DistBB, L->getLoopLatch()); + writeToAsmFile(*F->getParent()); SmallVector DivergeBlocks; - for (auto & BB : TopBlocks) { - for (auto * Succ : successors(BB)) { - Succ->printAsOperand(errs()); - if (std::find(BottomBlocks.begin(), BottomBlocks.end(), Succ) != BottomBlocks.end()) - DivergeBlocks.push_back(BB); - } + auto DivergeBlock = TopHalf->getUniquePredecessor(); + if (DivergeBlock != L->getHeader()) { + DivergeBlocks.push_back(DivergeBlock); } for (auto & DivergeBB : DivergeBlocks) { @@ -115,24 +104,27 @@ void LoopSplitter::doLoopSplit(Function * F, Loop * L0, BasicBlock * SplitBlock) auto LBT = LoopBodyTraverser(L0); // Collect Blocks in range [OldEntry, MidBlock) - SmallVector TopLoopBlocks; + vector TopLoopBlocks; LBT.traverse(TopLoopBlocks, OldEntry, MidBlock); // Collect Blocks in range [MidBlock, Latch) - SmallVector BottomLoopBlocks; + vector BottomLoopBlocks; LBT.traverse(BottomLoopBlocks, MidBlock, L0->getLoopLatch()); auto BottomLoop = IRLoop(); BottomLoop.extractLoopSkeleton(L0); auto TopLoop = IRLoop(); - TopLoop.constructEmptyLoop(getLoopTripCount(L0), BottomLoop.getHeader()); + TopLoop.constructEmptyLoop(getLoopTripCount(L0), PreLoopBB ? PreLoopBB : BottomLoop.getHeader()); TopLoop.setLoopBlocks(TopLoopBlocks); BottomLoop.setLoopBlocks(BottomLoopBlocks); setSuccessor(PreLoopBB, TopLoop.getPreHeader()); - setSuccessor(TopLoop.getHeader(), BottomLoop.getHeader(), 1); + // TODO Setting index value to 0 when entering new loop. + BasicBlock * BLoopEntry = BottomLoop.getPreHeader() == &F->getEntryBlock() ? + BottomLoop.getHeader() : BottomLoop.getPreHeader(); + setSuccessor(TopLoop.getHeader(), BLoopEntry, 1); setSuccessor(BottomLoop.getHeader(), PostLoopBB, 1); /* @@ -197,8 +189,6 @@ bool LoopSplitter::run() { bool changed = prepareForLoopSplit(F, L0, stat); if (!changed) return false; -// F->print(errs()); - DominatorTree DT (*F); LoopInfo NewLI (DT); L0 = *NewLI.begin(); @@ -207,7 +197,10 @@ bool LoopSplitter::run() { return false; } auto & SplitBB = LoopSplitEdgeBlocks.front(); -// doLoopSplit(F, L0, SplitBB); + + doLoopSplit(F, L0, SplitBB); + + assert(stat.AnnotatedVarsSize == 1 && "Atleast one annotation must be there!"); return true; } diff --git a/src/Util.cpp b/src/Util.cpp index 8d1f0ff19c40bf7e0a777a224a5b78c0aa819cc3..94dae9a8419f4718c06c22f09e8ab94c6a04cea5 100644 --- a/src/Util.cpp +++ b/src/Util.cpp @@ -115,7 +115,6 @@ SmallVector detectExpPtrUses(SmallVectorImpl & Annotated for_each(AnnotatedVars, [&] (const auto & Var) { auto FU = findEarliestPointerDerefInstruction(Var); - errs() << *FU << "\n"; if (!FU) return; VarUsePoints.push_back(const_cast(FU)); }); @@ -480,10 +479,11 @@ void visitSuccessor(SmallVectorImpl & Blocks, BasicBlock * StartBl } } -Value * getLoopTripCount(Loop * L0) { +AllocaInst * getLoopTripCount(Loop * L0) { auto Header = L0->getHeader(); auto Cond = cast(Header->getTerminator())->getCondition(); - return cast(Cond)->getOperand(1); + auto TCVal = cast(Cond)->getOperand(1); + return cast(cast(TCVal)->getOperand(0)); } BasicBlock * getPreLoopBlock(Loop * L) { diff --git a/src/Util.h b/src/Util.h index 641fd043949c5e6f734f21443f791e7eda0362c4..968a44bd261d333eb18aaf613ee07c1f4943a943 100644 --- a/src/Util.h +++ b/src/Util.h @@ -112,7 +112,7 @@ void visitSuccessor(llvm::SmallVectorImpl & Blocks, llvm::BasicBlock * CurBlock, llvm::BasicBlock * EndBlock); -llvm::Value * getLoopTripCount(llvm::Loop * L0); +llvm::AllocaInst * getLoopTripCount(llvm::Loop * L0); llvm::BasicBlock * getPreLoopBlock(llvm::Loop * L); diff --git a/test/unittests/blockPredication_test.cpp b/test/unittests/blockPredication_test.cpp index 16f0f3ea105163c42cbd3fbe012cc51f7e38ce4a..a0373110bd7d6c8d7f9359ec666e9524b5b76000 100644 --- a/test/unittests/blockPredication_test.cpp +++ b/test/unittests/blockPredication_test.cpp @@ -167,7 +167,6 @@ TEST_CASE("fast_flows_packet") { REQUIRE(M != nullptr); auto F = M->getFunction("fast_flows_packet"); - errs() << "Running block predication\n"; BlockPredication BP(F); BP.run(); diff --git a/test/unittests/loopSplitter_test.cpp b/test/unittests/loopSplitter_test.cpp index 59c44b7fd00d55977a3d887ef420a672be37ea91..312d59a087a0c123a045fcfc18c057440d556d5c 100644 --- a/test/unittests/loopSplitter_test.cpp +++ b/test/unittests/loopSplitter_test.cpp @@ -61,9 +61,8 @@ TEST_CASE("fn with single loop") { LoopSplitter LS(F, &LI); LS.run(); - auto Stats = LS.getStats(); - REQUIRE(Stats.AnnotatedVarsSize == 1); - REQUIRE(Stats.VarUsePointsSize == 1); + REQUIRE(LS.getStats().AnnotatedVarsSize == 1); + REQUIRE(LS.getStats().VarUsePointsSize == 1); //F->print(errs()); auto asmFile = writeToAsmFile(*M); @@ -76,3 +75,24 @@ TEST_CASE("fn with single loop") { auto ret = system(binary.c_str()); REQUIRE(ret == 0); } + +TEST_CASE("fast_flows_packet loop split") { + std::string filePrefix = "fast_flows"; + auto M = parseIR(generateIR(filePrefix + string(".c"), input_dir, true), input_dir); + REQUIRE(M != nullptr); + M->setSourceFileName(filePrefix + string("_batch.ll")); + { + std::string functionName = "fast_flows_packet"; + + auto F = M->getFunction(functionName); + DominatorTree DT(*F); + LoopInfo LI(DT); + + LoopSplitter LS(F, &LI); + LS.run(); + + //F->print(errs()); + auto asmFile = writeToAsmFile(*M); + auto TestObject = generateObject(asmFile); + } +} diff --git a/tools/tasopt.cpp b/tools/tasopt.cpp index c4ecfd89cb93ccaaabec0df93d76286c6205de83..73c41b4232cf5c0255d6c14ebc47aba277b3f370 100644 --- a/tools/tasopt.cpp +++ b/tools/tasopt.cpp @@ -77,7 +77,6 @@ int main(int argc, char * argv[]) { for (auto & FnStr : FnLists) { if (FnStr.second.compare("tas_block_predicate") != 0) continue; // Block Predication - errs() << "Running block predication on " << FnStr.second << "\n"; tas::BlockPredication BP(FnStr.first); auto res = BP.run(); if (!res) { @@ -90,6 +89,7 @@ int main(int argc, char * argv[]) { tas::BatchMaker BM(FnStr.first); auto BatchFunc = BM.run(); + writeToAsmFile(*M); // Loop Splitting DominatorTree DT(*BatchFunc); LoopInfo LI(DT);