diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -768,6 +768,27 @@ def LLVM_experimental_vector_reduce_v2_fadd : LLVM_VectorReductionV2<"fadd">; def LLVM_experimental_vector_reduce_v2_fmul : LLVM_VectorReductionV2<"fmul">; +// +// LLVM Matrix operations. +// + +/// Create a llvm.matrix.multiply call, as specified in the LLVM MatrixBuilder, +/// multiplying LHS (lhs_rows x lhs_columns) by RHS (lhs_columns x rhs_columns). +def LLVM_MatrixMultiplyOp + : LLVM_OneResultOp<"intr.matrix.multiply">, + Arguments<( + ins LLVM_Type:$lhs, LLVM_Type:$rhs, + I32Attr:$lhs_rows, I32Attr:$lhs_columns, I32Attr:$rhs_columns)> { + string llvmBuilder = [{ + llvm::MatrixBuilder<decltype(builder)> mb(builder); + $res = mb.CreateMatrixMultiply( + $lhs, $rhs, $lhs_rows.getZExtValue(), $lhs_columns.getZExtValue(), + $rhs_columns.getZExtValue()); + }]; + let assemblyFormat = "$lhs `,` $rhs attr-dict " + "`:` `(` type($lhs) `,` type($rhs) `)` `->` type($res)"; +} + // // Atomic operations. 
// diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h --- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h +++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h @@ -23,6 +23,7 @@ #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" +#include "llvm/IR/MatrixBuilder.h" #include "llvm/IR/Value.h" namespace mlir { diff --git a/mlir/test/Target/llvmir-intrinsics.mlir b/mlir/test/Target/llvmir-intrinsics.mlir --- a/mlir/test/Target/llvmir-intrinsics.mlir +++ b/mlir/test/Target/llvmir-intrinsics.mlir @@ -130,6 +130,19 @@ llvm.return } +// CHECK-LABEL: @matrix_intrinsics +// 4x16 16x3 +llvm.func @matrix_intrinsics(%A: !llvm<"<64 x float>">, %B: !llvm<"<48 x float>">) +// 4x3 + -> !llvm<"<12 x float>"> +{ + // CHECK: call <12 x float> @llvm.matrix.multiply.v12f32.v64f32.v48f32(<64 x float> %0, <48 x float> %1, i32 4, i32 16, i32 3) + %C = llvm.intr.matrix.multiply %A, %B + { lhs_rows = 4 : i32, lhs_columns = 16 : i32, rhs_columns = 3 : i32} : + (!llvm<"<64 x float>">, !llvm<"<48 x float>">) -> !llvm<"<12 x float>"> + llvm.return %C: !llvm<"<12 x float>"> +} + // Check that intrinsics are declared with appropriate types. // CHECK-DAG: declare float @llvm.fma.f32(float, float, float) // CHECK-DAG: declare <8 x float> @llvm.fma.v8f32(<8 x float>, <8 x float>, <8 x float>) #0 @@ -153,3 +166,4 @@ // CHECK-DAG: declare float @llvm.cos.f32(float) // CHECK-DAG: declare <8 x float> @llvm.cos.v8f32(<8 x float>) #0 // CHECK-DAG: declare float @llvm.copysign.f32(float, float) +// CHECK-DAG: declare <12 x float> @llvm.matrix.multiply.v12f32.v64f32.v48f32(<64 x float>, <48 x float>, i32 immarg, i32 immarg, i32 immarg)