diff --git a/llvm/include/llvm/IR/MatrixBuilder.h b/llvm/include/llvm/IR/MatrixBuilder.h
new file mode 100644
--- /dev/null
+++ b/llvm/include/llvm/IR/MatrixBuilder.h
@@ -0,0 +1,209 @@
+//===- llvm/MatrixBuilder.h - Builder to lower matrix ops -------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the MatrixBuilder class, which is used as a convenient way
+// to lower matrix operations to LLVM IR.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_IR_MATRIXBUILDER_H
+#define LLVM_IR_MATRIXBUILDER_H
+
+#include "llvm/IR/Constant.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/InstrTypes.h"
+#include "llvm/IR/Instruction.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/Type.h"
+#include "llvm/IR/Value.h"
+
+namespace llvm {
+
+class Function;
+class Twine;
+class Module;
+
+/// Convenience wrapper around an IRBuilder for emitting matrix operations.
+///
+/// Matrices are passed around as flat vectors. The element-access helpers
+/// below use the flattened index Column * NumRows + Row (column-major order).
+template <class IRBuilderTy> class MatrixBuilder {
+  IRBuilderTy &B;
+
+  /// Module containing the builder's current insertion block; the builder
+  /// must have an insertion point inside a function.
+  Module *getModule() { return B.GetInsertBlock()->getParent()->getParent(); }
+
+  /// Emit the flattened column-major index ColumnIdx * NumRows + RowIdx.
+  /// RowIdx and ColumnIdx presumably need the same integer type, since they
+  /// are combined with a single add.
+  Value *createColumnMajorIndex(Value *RowIdx, Value *ColumnIdx,
+                                unsigned NumRows) {
+    Value *ColumnStart = B.CreateMul(
+        ColumnIdx, ConstantInt::get(ColumnIdx->getType(), NumRows));
+    return B.CreateAdd(ColumnStart, RowIdx);
+  }
+
+public:
+  MatrixBuilder(IRBuilderTy &Builder) : B(Builder) {}
+
+  /// Create a call to the llvm.matrix.columnwise.load intrinsic.
+  /// \p DataPtr - Start address of the matrix read; its pointee type is the
+  ///              element type of the result.
+  /// \p Rows    - Number of rows in matrix (must be a ConstantInt)
+  /// \p Columns - Number of columns in matrix (must be a ConstantInt)
+  /// \p Stride  - Space between columns
+  /// \p ResultNumElementsFlattened - Number of elements in the returned flat
+  ///              vector (presumably Rows * Columns; not checked here).
+  /// \p Name    - Name of the emitted call.
+  CallInst *CreateMatrixColumnwiseLoad(Value *DataPtr, Value *Rows,
+                                       Value *Columns, Value *Stride,
+                                       unsigned ResultNumElementsFlattened,
+                                       const Twine &Name = "") {
+    // The element type of the loaded vector comes from the pointee type.
+    PointerType *PtrTy = cast<PointerType>(DataPtr->getType());
+    Type *EltTy = PtrTy->getElementType();
+
+    assert(isa<ConstantInt>(Rows) && isa<ConstantInt>(Columns) &&
+           "The number of rows and columns must be a compile-time constant");
+
+    Type *RetType = VectorType::get(EltTy, ResultNumElementsFlattened);
+
+    // Operands are passed as: pointer, stride, rows, columns.
+    Value *Ops[] = {DataPtr, Stride, Rows, Columns};
+    Type *OverloadedTypes[] = {RetType, PtrTy};
+
+    Function *TheFn = Intrinsic::getDeclaration(
+        getModule(), Intrinsic::matrix_columnwise_load, OverloadedTypes);
+
+    return B.createCallHelper(TheFn, Ops, &B, Name);
+  }
+
+  /// Create a call to the llvm.matrix.columnwise.store intrinsic.
+  /// \p Matrix  - Matrix (flat vector) to store
+  /// \p Ptr     - Pointer to write back to
+  /// \p Stride  - Space between columns
+  /// \p Rows    - Number of rows in \p Matrix
+  /// \p Columns - Number of columns in \p Matrix
+  /// \p Name    - Name of the emitted call. NOTE(review): this intrinsic
+  ///              presumably returns void, and LLVM asserts when a void value
+  ///              is named, so callers should likely leave this empty.
+  CallInst *CreateMatrixColumnwiseStore(Value *Matrix, Value *Ptr,
+                                        Value *Stride, unsigned Rows,
+                                        unsigned Columns,
+                                        const Twine &Name = "") {
+    // Operands are passed as: matrix, pointer, stride, rows, columns.
+    Value *Ops[] = {Matrix, Ptr, Stride, B.getInt32(Rows), B.getInt32(Columns)};
+    Type *OverloadedTypes[] = {Matrix->getType(), Ptr->getType()};
+
+    Function *TheFn = Intrinsic::getDeclaration(
+        getModule(), Intrinsic::matrix_columnwise_store, OverloadedTypes);
+
+    return B.createCallHelper(TheFn, Ops, &B, Name);
+  }
+
+  /// Create a call to the llvm.matrix.transpose intrinsic.
+  /// \p Matrix  - Matrix (flat vector) to transpose
+  /// \p Rows    - Number of rows in \p Matrix
+  /// \p Columns - Number of columns in \p Matrix
+  /// \p ResultNumElementsFlattened - Number of elements in the returned flat
+  ///              vector (presumably Rows * Columns; not checked here).
+  /// \p Name    - Name of the emitted call.
+  CallInst *CreateMatrixTranspose(Value *Matrix, unsigned Rows,
+                                  unsigned Columns,
+                                  unsigned ResultNumElementsFlattened,
+                                  const Twine &Name = "") {
+    auto *OpType = cast<VectorType>(Matrix->getType());
+    Type *ReturnType =
+        VectorType::get(OpType->getElementType(), ResultNumElementsFlattened);
+
+    Type *OverloadedTypes[] = {ReturnType};
+    Value *Ops[] = {Matrix, B.getInt32(Rows), B.getInt32(Columns)};
+    Function *TheFn = Intrinsic::getDeclaration(
+        getModule(), Intrinsic::matrix_transpose, OverloadedTypes);
+
+    return B.createCallHelper(TheFn, Ops, &B, Name);
+  }
+
+  /// Create a call to the llvm.matrix.multiply intrinsic, multiplying
+  /// \p Operand1 (\p Op1Rows x \p Op1Columns) by \p Operand2 (presumably
+  /// \p Op1Columns x \p Op2Columns).
+  /// \p ResultNumElementsFlattened - Number of elements in the returned flat
+  ///              vector (presumably Op1Rows * Op2Columns; not checked here).
+  /// \p Name    - Name of the emitted call.
+  CallInst *CreateMatrixMultiply(Value *Operand1, Value *Operand2,
+                                 unsigned Op1Rows, unsigned Op1Columns,
+                                 unsigned Op2Columns,
+                                 unsigned ResultNumElementsFlattened,
+                                 const Twine &Name = "") {
+    auto *Op1Type = cast<VectorType>(Operand1->getType());
+    auto *Op2Type = cast<VectorType>(Operand2->getType());
+
+    Type *ReturnType =
+        VectorType::get(Op1Type->getElementType(), ResultNumElementsFlattened);
+
+    Value *Ops[] = {Operand1, Operand2, B.getInt32(Op1Rows),
+                    B.getInt32(Op1Columns), B.getInt32(Op2Columns)};
+    // ReturnType @llvm.matrix.multiply(Op1Type, Op2Type, i32, i32, i32)
+    Type *OverloadedTypes[] = {ReturnType, Op1Type, Op2Type};
+
+    Function *TheFn = Intrinsic::getDeclaration(
+        getModule(), Intrinsic::matrix_multiply, OverloadedTypes);
+    return B.createCallHelper(TheFn, Ops, &B, Name);
+  }
+
+  /// Insert \p NewVal into \p Matrix (which has \p NumRows rows) at
+  /// (\p RowIdx, \p ColumnIdx) and return the updated matrix.
+  Value *CreateMatrixInsert(Value *Matrix, Value *NewVal, Value *RowIdx,
+                            Value *ColumnIdx, unsigned NumRows) {
+    return B.CreateInsertElement(
+        Matrix, NewVal, createColumnMajorIndex(RowIdx, ColumnIdx, NumRows));
+  }
+
+  /// Element-wise matrix addition: fadd for floating-point element types,
+  /// add otherwise.
+  Value *CreateAdd(Value *LHS, Value *RHS) {
+    Type *EltTy = cast<VectorType>(LHS->getType())->getElementType();
+    return EltTy->isFloatingPointTy() ? B.CreateFAdd(LHS, RHS)
+                                      : B.CreateAdd(LHS, RHS);
+  }
+
+  /// Element-wise matrix subtraction: fsub for floating-point element types,
+  /// sub otherwise.
+  Value *CreateSub(Value *LHS, Value *RHS) {
+    Type *EltTy = cast<VectorType>(LHS->getType())->getElementType();
+    return EltTy->isFloatingPointTy() ? B.CreateFSub(LHS, RHS)
+                                      : B.CreateSub(LHS, RHS);
+  }
+
+  /// Multiply matrix \p LHS by scalar \p RHS: \p RHS is splatted to a vector
+  /// with as many elements as \p LHS and multiplied element-wise (fmul for a
+  /// floating-point scalar, mul otherwise).
+  Value *CreateScalarMultiply(Value *LHS, Value *RHS) {
+    Value *ScalarVector =
+        B.CreateVectorSplat(cast<VectorType>(LHS->getType())->getNumElements(),
+                            RHS, "scalar.splat");
+    if (RHS->getType()->isFloatingPointTy())
+      return B.CreateFMul(LHS, ScalarVector);
+
+    return B.CreateMul(LHS, ScalarVector);
+  }
+
+  /// Extracts the element at (\p Row, \p Column) from \p Matrix, which has
+  /// \p NumRows rows. \p Name is forwarded to the emitted extractelement.
+  Value *CreateExtractMatrix(Value *Matrix, Value *Row, Value *Column,
+                             unsigned NumRows, Twine const &Name = "") {
+    return B.CreateExtractElement(
+        Matrix, createColumnMajorIndex(Row, Column, NumRows), Name);
+  }
+};
+
+} // end namespace llvm
+
+#endif // LLVM_IR_MATRIXBUILDER_H