diff --git a/llvm/lib/IR/Verifier.cpp b/llvm/lib/IR/Verifier.cpp
--- a/llvm/lib/IR/Verifier.cpp
+++ b/llvm/lib/IR/Verifier.cpp
@@ -5638,15 +5638,30 @@
     Type *Op0ElemTy = nullptr;
     Type *Op1ElemTy = nullptr;
     switch (ID) {
-    case Intrinsic::matrix_multiply:
+    case Intrinsic::matrix_multiply: {
       NumRows = cast<ConstantInt>(Call.getArgOperand(2));
+      ConstantInt *N = cast<ConstantInt>(Call.getArgOperand(3));
       NumColumns = cast<ConstantInt>(Call.getArgOperand(4));
+      // Operand 0 must hold a NumRows x N matrix and operand 1 an
+      // N x NumColumns matrix; reject calls whose vectors disagree.
+      Check(cast<FixedVectorType>(Call.getArgOperand(0)->getType())
+                    ->getNumElements() ==
+                NumRows->getZExtValue() * N->getZExtValue(),
+            "First argument of a matrix operation does not match specified "
+            "shape!");
+      Check(cast<FixedVectorType>(Call.getArgOperand(1)->getType())
+                    ->getNumElements() ==
+                N->getZExtValue() * NumColumns->getZExtValue(),
+            "Second argument of a matrix operation does not match specified "
+            "shape!");
+
       ResultTy = cast<VectorType>(Call.getType());
       Op0ElemTy =
           cast<VectorType>(Call.getArgOperand(0)->getType())->getElementType();
       Op1ElemTy =
           cast<VectorType>(Call.getArgOperand(1)->getType())->getElementType();
       break;
+    }
     case Intrinsic::matrix_transpose:
       NumRows = cast<ConstantInt>(Call.getArgOperand(1));
       NumColumns = cast<ConstantInt>(Call.getArgOperand(2));
diff --git a/llvm/test/Verifier/matrix-intrinsics.ll b/llvm/test/Verifier/matrix-intrinsics.ll
--- a/llvm/test/Verifier/matrix-intrinsics.ll
+++ b/llvm/test/Verifier/matrix-intrinsics.ll
@@ -1,4 +1,4 @@
-; RUN: not llvm-as -opaque-pointers < %s -o /dev/null 2>&1 | FileCheck %s
+; RUN: not llvm-as < %s -o /dev/null 2>&1 | FileCheck %s
 
 define <4 x float> @transpose(<4 x float> %m, i32 %arg) {
 ; CHECK: assembly parsed, but does not verify as correct!
@@ -20,17 +20,19 @@
 }
 
 define <4 x float> @multiply(<4 x float> %m, i32 %arg) {
-; CHECK-NEXT: Result of a matrix operation does not fit in the returned vector!
-; CHECK-NEXT: Result of a matrix operation does not fit in the returned vector!
+; CHECK-NEXT: First argument of a matrix operation does not match specified shape!
+; CHECK-NEXT: First argument of a matrix operation does not match specified shape!
+; CHECK-NEXT: Second argument of a matrix operation does not match specified shape!
 ; CHECK-NEXT: Result of a matrix operation does not fit in the returned vector!
 ; CHECK-NEXT: immarg operand has non-immediate parameter
 ; CHECK-NEXT: i32 %arg
-; CHECK-NEXT: %result.3 = call <4 x float> @llvm.matrix.multiply.v4f32.v4f32.v4f32(<4 x float> %result.2, <4 x float> %m, i32 %arg, i32 2, i32 1)
+; CHECK-NEXT: %result.4 = call <4 x float> @llvm.matrix.multiply.v4f32.v4f32.v4f32(<4 x float> %result.2, <4 x float> %m, i32 %arg, i32 2, i32 1)
   %result.0 = call <4 x float> @llvm.matrix.multiply.v4f32.v4f32.v4f32(<4 x float> %m, <4 x float> %m, i32 0, i32 0, i32 0)
   %result.1 = call <4 x float> @llvm.matrix.multiply.v4f32.v4f32.v4f32(<4 x float> %result.0, <4 x float> %m, i32 3, i32 2, i32 2)
   %result.2 = call <4 x float> @llvm.matrix.multiply.v4f32.v4f32.v4f32(<4 x float> %result.1, <4 x float> %m, i32 2, i32 2, i32 1)
-  %result.3 = call <4 x float> @llvm.matrix.multiply.v4f32.v4f32.v4f32(<4 x float> %result.2, <4 x float> %m, i32 %arg, i32 2, i32 1)
-  ret <4 x float> %result.3
+  %result.3 = call <3 x float> @llvm.matrix.multiply.v3f32.v4f32.v4f32(<4 x float> %result.2, <4 x float> %m, i32 2, i32 2, i32 2)
+  %result.4 = call <4 x float> @llvm.matrix.multiply.v4f32.v4f32.v4f32(<4 x float> %result.2, <4 x float> %m, i32 %arg, i32 2, i32 1)
+  ret <4 x float> %result.4
 }
 
 define <4 x float> @column.major_load(ptr %m, ptr %n, i32 %arg) {
@@ -136,3 +138,4 @@
 declare <4 x float> @llvm.matrix.multiply.v4f32.v4f32.v4i32(<4 x float>, <4 x i32>, i32, i32, i32)
 declare <4 x float> @llvm.matrix.multiply.v4f32.v4i32.v4i32(<4 x i32>, <4 x i32>, i32, i32, i32)
 declare <4 x float> @llvm.matrix.multiply.v4f32.v4f32.v4f32(<4 x float>, <4 x float>, i32, i32, i32)
+declare <3 x float> @llvm.matrix.multiply.v3f32.v4f32.v4f32(<4 x float>, <4 x float>, i32, i32, i32)