diff --git a/llvm/include/llvm/Transforms/Utils/MatrixUtils.h b/llvm/include/llvm/Transforms/Utils/MatrixUtils.h
new file mode 100644
--- /dev/null
+++ b/llvm/include/llvm/Transforms/Utils/MatrixUtils.h
@@ -0,0 +1,83 @@
+//===- MatrixUtils.h - Utilities to lower matrix intrinsics -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Utilities for lowering matrix intrinsics, e.g. by generating tiled loop
+// nests.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_TRANSFORMS_UTILS_MATRIXUTILS_H
+#define LLVM_TRANSFORMS_UTILS_MATRIXUTILS_H
+
+#include "llvm/ADT/StringRef.h"
+#include "llvm/IR/IRBuilder.h"
+
+namespace llvm {
+class DomTreeUpdater;
+class BasicBlock;
+class Value;
+class Loop;
+class LoopInfo;
+
+/// Describes a tiled matrix computation and, once CreateTiledLoops has been
+/// called, the loop-nest skeleton generated for it.
+struct TileInfo {
+  /// Number of rows of the matrix.
+  unsigned NumRows;
+
+  /// Number of columns of the matrix.
+  unsigned NumColumns;
+
+  /// Number of columns of the first matrix of a multiply /
+  /// number of rows of the second matrix of a multiply.
+  unsigned NumInner;
+
+  /// Number of rows/columns in a tile. The generated loops advance by TileSize
+  /// and exit once the induction variable equals the bound, so NumRows,
+  /// NumColumns and NumInner are expected to be non-zero multiples of it.
+  unsigned TileSize = -1;
+
+  /// Start row of the current tile to compute.
+  Value *CurrentRow = nullptr;
+
+  /// Start column of the current tile to compute.
+  Value *CurrentCol = nullptr;
+
+  /// Current tile offset during the tile computation.
+  Value *CurrentK = nullptr;
+
+  /// Blocks of the generated loop nest; set by CreateTiledLoops.
+  BasicBlock *ColumnLoopHeader = nullptr;
+  BasicBlock *RowLoopHeader = nullptr;
+  BasicBlock *RowLoopLatch = nullptr;
+  BasicBlock *InnerLoopHeader = nullptr;
+  BasicBlock *InnerLoopLatch = nullptr;
+
+  TileInfo(unsigned NumRows, unsigned NumColumns, unsigned NumInner,
+           unsigned TileSize)
+      : NumRows(NumRows), NumColumns(NumColumns), NumInner(NumInner),
+        TileSize(TileSize) {}
+
+  /// Creates a column/row/inner loop nest between \p Start and \p End,
+  /// registers the new loops with \p LI, keeps \p DTU up to date and returns
+  /// the body block of the innermost loop.
+  BasicBlock *CreateTiledLoops(BasicBlock *Start, BasicBlock *End,
+                               IRBuilder<> &B, DomTreeUpdater &DTU,
+                               LoopInfo &LI);
+
+private:
+  /// Creates a loop between \p Preheader and \p Exit stepping an IV from 0 by
+  /// \p Step until it equals \p Bound; returns the loop's body block.
+  static BasicBlock *CreateLoop(BasicBlock *Preheader, BasicBlock *Exit,
+                                Value *Bound, Value *Step, StringRef Name,
+                                IRBuilder<> &B, DomTreeUpdater &DTU, Loop *L,
+                                LoopInfo &LI);
+};
+} // namespace llvm
+
+#endif
diff --git a/llvm/lib/Transforms/Utils/CMakeLists.txt b/llvm/lib/Transforms/Utils/CMakeLists.txt
--- a/llvm/lib/Transforms/Utils/CMakeLists.txt
+++ b/llvm/lib/Transforms/Utils/CMakeLists.txt
@@ -44,6 +44,7 @@
   LowerInvoke.cpp
   LowerMemIntrinsics.cpp
   LowerSwitch.cpp
+  MatrixUtils.cpp
   Mem2Reg.cpp
   MetaRenamer.cpp
   MisExpect.cpp
diff --git a/llvm/lib/Transforms/Utils/MatrixUtils.cpp b/llvm/lib/Transforms/Utils/MatrixUtils.cpp
new file mode 100644
--- /dev/null
+++ b/llvm/lib/Transforms/Utils/MatrixUtils.cpp
@@ -0,0 +1,128 @@
+//===- MatrixUtils.cpp - Utilities to lower matrix intrinsics ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Utilities for lowering matrix intrinsics, e.g. by generating tiled loop
+// nests.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/Utils/MatrixUtils.h"
+#include "llvm/Analysis/DomTreeUpdater.h"
+#include "llvm/Analysis/LoopInfo.h"
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Type.h"
+
+using namespace llvm;
+
+// Creates a single loop whose blocks are inserted before Exit:
+//
+//   Preheader -> Header -> Body -> Latch -> Header (back edge)
+//                                        -> Exit   (once IV + Step == Bound)
+//
+// Header holds an i32 induction variable that starts at 0; Latch adds Step to
+// it and branches back to Header while the sum is != Bound. The loop is
+// bottom-tested, so Body runs at least once, and the trip count is
+// Bound / Step only if Bound is a non-zero multiple of Step. The first
+// successor of Preheader's branch is redirected to Header, the dominator tree
+// and L are updated, and B is left at the end of Latch (after its
+// terminator). Returns Body, where the caller emits the loop contents.
+BasicBlock *TileInfo::CreateLoop(BasicBlock *Preheader, BasicBlock *Exit,
+                                 Value *Bound, Value *Step, StringRef Name,
+                                 IRBuilder<> &B, DomTreeUpdater &DTU, Loop *L,
+                                 LoopInfo &LI) {
+  LLVMContext &Ctx = Preheader->getContext();
+  BasicBlock *Header =
+      BasicBlock::Create(Ctx, Name + ".header", Preheader->getParent(), Exit);
+  BasicBlock *Body =
+      BasicBlock::Create(Ctx, Name + ".body", Header->getParent(), Exit);
+  BasicBlock *Latch =
+      BasicBlock::Create(Ctx, Name + ".latch", Header->getParent(), Exit);
+
+  Type *I32Ty = Type::getInt32Ty(Ctx);
+  BranchInst::Create(Body, Header);
+  BranchInst::Create(Latch, Body);
+  PHINode *IV =
+      PHINode::Create(I32Ty, 2, Name + ".iv", Header->getTerminator());
+  IV->addIncoming(ConstantInt::get(I32Ty, 0), Preheader);
+
+  B.SetInsertPoint(Latch);
+  Value *Inc = B.CreateAdd(IV, Step, Name + ".step");
+  Value *Cond = B.CreateICmpNE(Inc, Bound, Name + ".cond");
+  BranchInst::Create(Header, Exit, Cond, Latch);
+  IV->addIncoming(Inc, Latch);
+
+  // Splice the loop in by redirecting the preheader's first successor.
+  BranchInst *PreheaderBr = cast<BranchInst>(Preheader->getTerminator());
+  BasicBlock *OldSucc = PreheaderBr->getSuccessor(0);
+  PreheaderBr->setSuccessor(0, Header);
+  DTU.applyUpdatesPermissive({
+      {DominatorTree::Delete, Preheader, OldSucc},
+      {DominatorTree::Insert, Header, Body},
+      {DominatorTree::Insert, Body, Latch},
+      {DominatorTree::Insert, Latch, Header},
+      {DominatorTree::Insert, Latch, Exit},
+      {DominatorTree::Insert, Preheader, Header},
+  });
+
+  L->addBasicBlockToLoop(Header, LI);
+  L->addBasicBlockToLoop(Body, LI);
+  L->addBasicBlockToLoop(Latch, LI);
+  return Body;
+}
+
+// Creates the following loop nest skeleton:
+//   for C = 0; C < NumColumns; C += TileSize
+//     for R = 0; R < NumRows; R += TileSize
+//       for K = 0; K < NumInner; K += TileSize
+// Each loop is built by CreateLoop and is therefore bottom-tested with a
+// `!=` exit condition, so NumColumns, NumRows and NumInner should be
+// non-zero multiples of TileSize. The column loop is spliced in between
+// Start and End and nested in the loop containing Start, if any. Records the
+// loop headers, latches and induction variables in this TileInfo and returns
+// the body of the innermost loop.
+BasicBlock *TileInfo::CreateTiledLoops(BasicBlock *Start, BasicBlock *End,
+                                       IRBuilder<> &B, DomTreeUpdater &DTU,
+                                       LoopInfo &LI) {
+  // Build the Loop objects for the nest first so CreateLoop can register the
+  // new blocks with the right loops.
+  Loop *ColLoop = LI.AllocateLoop();
+  Loop *RowLoop = LI.AllocateLoop();
+  Loop *InnerLoop = LI.AllocateLoop();
+  RowLoop->addChildLoop(InnerLoop);
+  ColLoop->addChildLoop(RowLoop);
+  if (Loop *ParentL = LI.getLoopFor(Start))
+    ParentL->addChildLoop(ColLoop);
+  else
+    LI.addTopLevelLoop(ColLoop);
+
+  BasicBlock *ColBody =
+      CreateLoop(Start, End, B.getInt32(NumColumns), B.getInt32(TileSize),
+                 "cols", B, DTU, ColLoop, LI);
+  BasicBlock *ColLatch = ColBody->getSingleSuccessor();
+  BasicBlock *RowBody =
+      CreateLoop(ColBody, ColLatch, B.getInt32(NumRows), B.getInt32(TileSize),
+                 "rows", B, DTU, RowLoop, LI);
+  RowLoopLatch = RowBody->getSingleSuccessor();
+
+  BasicBlock *InnerBody =
+      CreateLoop(RowBody, RowLoopLatch, B.getInt32(NumInner),
+                 B.getInt32(TileSize), "inner", B, DTU, InnerLoop, LI);
+  InnerLoopLatch = InnerBody->getSingleSuccessor();
+  ColumnLoopHeader = ColBody->getSinglePredecessor();
+  RowLoopHeader = RowBody->getSinglePredecessor();
+  InnerLoopHeader = InnerBody->getSinglePredecessor();
+  // CreateLoop makes the induction-variable PHI the first instruction of each
+  // header.
+  CurrentRow = &*RowLoopHeader->begin();
+  CurrentCol = &*ColumnLoopHeader->begin();
+  CurrentK = &*InnerLoopHeader->begin();
+
+  return InnerBody;
+}