diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -764,6 +764,74 @@
 def LLVM_experimental_vector_reduce_v2_fadd : LLVM_VectorReductionV2<"fadd">;
 def LLVM_experimental_vector_reduce_v2_fmul : LLVM_VectorReductionV2<"fmul">;
 
+//
+// LLVM Matrix operations.
+//
+
+// TODO(ntv, zinenko): Provide more idiomatic tblgen support.
+// On the other hand, there are only 4 intrinsics in total at the moment.
+// TODO(ntv): Switch to the commented-out definition below once the LLVM integrate has landed.
+// def LLVM_MatrixMultiplyOp
+//   : LLVM_OneResultOp<"intr.matrix.multiply">,
+//     Arguments<(
+//       ins LLVM_Type:$A, LLVM_Type:$B, I32Attr:$M, I32Attr:$N)> {
+//   string llvmBuilder =
+//     "$res = MatrixBuilder::CreateMatrixMultiply("
+//     "$A, $B, $M.getZExtValue(), $N.getZExtValue());";
+//   let builders = [
+//     OpBuilder<"Builder *builder, OperationState &result, Value A, Value B, "
+//               "IntegerAttr M, IntegerAttr N", [{
+//       result.addAttribute("M", M);
+//       result.addAttribute("N", N);
+//       result.addTypes(VectorType::get(
+//         A.getType().cast<LLVMType>()->getElementType(),
+//         M.getZExtValue() * N.getZExtValue()));}]>,
+//   ];
+//   let assemblyFormat =
+//     "$A `,` $B attr-dict `:` type($A) `,` type($B) `->` type($res)";
+// }
+
+def LLVM_MatrixMultiplyOp
+  : LLVM_OneResultOp<"intr.matrix.multiply">,
+    Arguments<(
+      ins LLVM_Type:$A, LLVM_Type:$B, I32Attr:$M, I32Attr:$N, I32Attr:$K)> {
+  string llvmBuilder = [{
+    $res = builder.CreateCall(
+      llvm::Intrinsic::getDeclaration(
+        builder.GetInsertBlock()->getModule(),
+        llvm::Intrinsic::matrix_multiply,
+        {
+          llvm::VectorType::get(
+            $A->getType()->getScalarType(),
+            $M.getZExtValue() * $N.getZExtValue()),
+          $A->getType(),
+          $B->getType()
+        }
+      ),
+      {
+        $A,
+        $B,
+        llvm::ConstantInt::get(builder.getInt32Ty(), $M.getZExtValue()),
+        // The LLVM intrinsic takes its dimension operands in (M, K, N) order.
+        llvm::ConstantInt::get(builder.getInt32Ty(), $K.getZExtValue()),
+        llvm::ConstantInt::get(builder.getInt32Ty(), $N.getZExtValue())
+      }
+    );
+  }];
+  let builders = [
+    OpBuilder<"Builder *builder, OperationState &result, LLVMType resType, "
+              "Value A, Value B, IntegerAttr M, IntegerAttr N, IntegerAttr K",
+    [{
+      result.addAttribute("M", M);
+      result.addAttribute("N", N);
+      result.addAttribute("K", K);
+      result.addTypes(resType);}]>,
+  ];
+
+  let assemblyFormat =
+    "$A `,` $B attr-dict `:` type($A) `,` type($B) `->` type($res)";
+}
+
 //
 // Atomic operations.
 //
diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
--- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
@@ -23,6 +23,8 @@
 #include "llvm/IR/Function.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/Value.h"
+// TODO(ntv): enable this include (presumably once the LLVM integrate lands).
+// #include "llvm/IR/MatrixBuilder.h"
 
 namespace mlir {
 class Attribute;
diff --git a/mlir/test/Target/llvmir-intrinsics.mlir b/mlir/test/Target/llvmir-intrinsics.mlir
--- a/mlir/test/Target/llvmir-intrinsics.mlir
+++ b/mlir/test/Target/llvmir-intrinsics.mlir
@@ -130,6 +130,18 @@
   llvm.return
 }
 
+// CHECK-LABEL: @matrix_intrinsics
+// %A holds a 4x16 (M x K) matrix and %B a 16x3 (K x N) matrix, each as a 1-D vector.
+llvm.func @matrix_intrinsics(%A: !llvm<"<64 x float>">, %B: !llvm<"<48 x float>">)
+// The result holds a 4x3 (M x N) matrix.
+  -> !llvm<"<12 x float>">
+{
+  // CHECK: call <12 x float> @llvm.matrix.multiply.v12f32.v64f32.v48f32(<64 x float> %0, <48 x float> %1, i32 4, i32 16, i32 3)
+  %C = llvm.intr.matrix.multiply %A, %B { M = 4: i32, N = 3: i32, K = 16: i32 } :
+    !llvm<"<64 x float>">, !llvm<"<48 x float>"> -> !llvm<"<12 x float>">
+  llvm.return %C: !llvm<"<12 x float>">
+}
+
 // Check that intrinsics are declared with appropriate types.
 // CHECK-DAG: declare float @llvm.fma.f32(float, float, float)
 // CHECK-DAG: declare <8 x float> @llvm.fma.v8f32(<8 x float>, <8 x float>, <8 x float>) #0
@@ -153,3 +165,4 @@
 // CHECK-DAG: declare float @llvm.cos.f32(float)
 // CHECK-DAG: declare <8 x float> @llvm.cos.v8f32(<8 x float>) #0
 // CHECK-DAG: declare float @llvm.copysign.f32(float, float)
+// CHECK-DAG: declare <12 x float> @llvm.matrix.multiply.v12f32.v64f32.v48f32(<64 x float>, <48 x float>, i32 immarg, i32 immarg, i32 immarg)