diff --git a/llvm/include/llvm/IR/MatrixBuilder.h b/llvm/include/llvm/IR/MatrixBuilder.h new file mode 100644 --- /dev/null +++ b/llvm/include/llvm/IR/MatrixBuilder.h @@ -0,0 +1,172 @@ +//===- llvm/MatrixBuilder.h - Builder to lower matrix ops -------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the MatrixBuilder class, which is used as a convenient way +// to lower matrix operations to LLVM IR. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_IR_MATRIXBUILDER_H +#define LLVM_IR_MATRIXBUILDER_H + +#include "llvm/IR/Constant.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/Value.h" + +namespace llvm { + +class Function; +class Twine; +class Module; + +template class MatrixBuilder { + IRBuilderTy &B; + Module *getModule() { return B.GetInsertBlock()->getParent()->getParent(); } + +public: + MatrixBuilder(IRBuilderTy &Builder) : B(Builder) {} + + /// Create a columnwise, strided matrix load. + /// \p DataPtr - Start address of the matrix read + /// \p Rows - Number of rows in matrix (must be a constant) + /// \p Columns - Number of columns in matrix (must be a constant) + /// \p Stride - Space between columns + CallInst *CreateMatrixColumnwiseLoad(Value *DataPtr, unsigned Rows, + unsigned Columns, Value *Stride, + const Twine &Name = "") { + + // Deal with the pointer + PointerType *PtrTy = cast(DataPtr->getType()); + Type *EltTy = PtrTy->getElementType(); + + Type *RetType = VectorType::get(EltTy, Rows * Columns); + + Value *Ops[] = {DataPtr, Stride, B.getInt32(Rows), B.getInt32(Columns)}; + Type *OverloadedTypes[] = {RetType, PtrTy}; + + Function *TheFn = Intrinsic::getDeclaration( + getModule(), Intrinsic::matrix_columnwise_load, OverloadedTypes); + + return B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name); + } + + /// Create a columnwise, strided matrix store. + /// \p Matrix - Matrix to store + /// \p Ptr - Pointer to write back to + /// \p Stride - Space between columns + CallInst *CreateMatrixColumnwiseStore(Value *Matrix, Value *Ptr, + Value *Stride, unsigned Rows, + unsigned Columns, + const Twine &Name = "") { + Value *Ops[] = {Matrix, Ptr, Stride, B.getInt32(Rows), B.getInt32(Columns)}; + Type *OverloadedTypes[] = {Matrix->getType(), Ptr->getType()}; + + Function *TheFn = Intrinsic::getDeclaration( + getModule(), Intrinsic::matrix_columnwise_store, OverloadedTypes); + + return B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name); + } + + /// Create a llvm.matrix.transpose call, transposing \p Matrix with \p Rows + /// rows and \p Columns columns. + CallInst *CreateMatrixTranspose(Value *Matrix, unsigned Rows, + unsigned Columns, const Twine &Name = "") { + auto *OpType = cast(Matrix->getType()); + Type *ReturnType = + VectorType::get(OpType->getElementType(), Rows * Columns); + + Type *OverloadedTypes[] = {ReturnType}; + Value *Ops[] = {Matrix, B.getInt32(Rows), B.getInt32(Columns)}; + Function *TheFn = Intrinsic::getDeclaration( + getModule(), Intrinsic::matrix_transpose, OverloadedTypes); + + return B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name); + } + + /// Create a llvm.matrix.multiply call, multiplying matrixes \p LHS and \p + /// RHS. + CallInst *CreateMatrixMultiply(Value *LHS, Value *RHS, unsigned LHSRows, + unsigned LHSColumns, unsigned RHSColumns, + const Twine &Name = "") { + auto *LHSType = cast(LHS->getType()); + auto *RHSType = cast(RHS->getType()); + + Type *ReturnType = + VectorType::get(LHSType->getElementType(), LHSRows * RHSColumns); + + Value *Ops[] = {LHS, RHS, B.getInt32(LHSRows), B.getInt32(LHSColumns), + B.getInt32(RHSColumns)}; + Type *OverloadedTypes[] = {ReturnType, LHSType, RHSType}; + + Function *TheFn = Intrinsic::getDeclaration( + getModule(), Intrinsic::matrix_multiply, OverloadedTypes); + return B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name); + } + + /// Insert a single element \p NewVal into \p Matrix at indices (\p RowIdx, \p + /// ColumnIdx). + Value *CreateMatrixInsert(Value *Matrix, Value *NewVal, Value *RowIdx, + Value *ColumnIdx, unsigned NumRows) { + return B.CreateInsertElement( + Matrix, NewVal, + B.CreateAdd(B.CreateMul(ColumnIdx, ConstantInt::get( + ColumnIdx->getType(), NumRows)), + RowIdx)); + } + + /// Add matrixes \p LHS and \p RHS. Support both integer and floating point + /// matrixes. + Value *CreateAdd(Value *LHS, Value *RHS) { + return cast(LHS->getType()) + ->getElementType() + ->isFloatingPointTy() + ? B.CreateFAdd(LHS, RHS) + : B.CreateAdd(LHS, RHS); + } + + /// Subtract matrixes \p LHS and \p RHS. Support both integer and floating + /// point matrixes. + Value *CreateSub(Value *LHS, Value *RHS) { + return cast(LHS->getType()) + ->getElementType() + ->isFloatingPointTy() + ? B.CreateFSub(LHS, RHS) + : B.CreateSub(LHS, RHS); + } + + /// Multiply matrix \p LHS with scalar \p RHS. + Value *CreateScalarMultiply(Value *LHS, Value *RHS) { + Value *ScalarVector = + B.CreateVectorSplat(cast(LHS->getType())->getNumElements(), + RHS, "scalar.splat"); + if (RHS->getType()->isFloatingPointTy()) + return B.CreateFMul(LHS, ScalarVector); + + return B.CreateMul(LHS, ScalarVector); + } + + /// Extracts the element at (\p Row, \p Column) from \p Matrix. + Value *CreateExtractMatrix(Value *Matrix, Value *Row, Value *Column, + unsigned NumRows, Twine const &Name = "") { + + return B.CreateExtractElement( + Matrix, + B.CreateAdd( + B.CreateMul(Column, ConstantInt::get(Column->getType(), NumRows)), + Row)); + } +}; + +} // end namespace llvm + +#endif // LLVM_IR_MATRIXBUILDER_H