Const-qualify the getElementType() and getVectorTy() accessors and add
element-type sanity checks where the matrix-multiply operands are loaded.

The new asserts call getElementType() through the `const MatrixTy &`
references Lhs and Rhs, which requires the accessor to be const;
getElementType() in turn calls getVectorTy(), so that accessor is made
const as well. The asserts check that both operands share an element type
and that the freshly constructed Result matrix uses the same element type.

diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -246,7 +246,7 @@
 
     void setVector(unsigned i, Value *V) { Vectors[i] = V; }
 
-    Type *getElementType() { return getVectorTy()->getElementType(); }
+    Type *getElementType() const { return getVectorTy()->getElementType(); }
 
     unsigned getNumVectors() const {
       if (isColumnMajor())
@@ -276,7 +276,7 @@
       return getVectorTy();
     }
 
-    VectorType *getVectorTy() {
+    VectorType *getVectorTy() const {
       return cast<VectorType>(Vectors[0]->getType());
     }
 
@@ -1370,6 +1370,8 @@
 
     const MatrixTy &Lhs = getMatrix(MatMul->getArgOperand(0), LShape, Builder);
     const MatrixTy &Rhs = getMatrix(MatMul->getArgOperand(1), RShape, Builder);
+    assert(Lhs.getElementType() == Rhs.getElementType() &&
+           "Matrix multiply argument element types do not match.");
 
     const unsigned R = LShape.NumRows;
     const unsigned C = RShape.NumColumns;
@@ -1377,6 +1379,8 @@
 
     // Initialize the output
     MatrixTy Result(R, C, EltType);
+    assert(Lhs.getElementType() == Result.getElementType() &&
+           "Matrix multiply result element type does not match arguments.");
 
     bool AllowContract = AllowContractEnabled || (isa<FPMathOperator>(MatMul) &&
                                                   MatMul->hasAllowContract());