diff --git a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h --- a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h +++ b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h @@ -15,6 +15,12 @@ class ModuleOp; template class OpPassBase; +/// Collect a set of patterns to convert from Vector contractions to LLVM Matrix +/// Intrinsics. To lower to assembly, the LLVM flag -lower-matrix-intrinsics +/// will be needed when invoking LLVM. +void populateVectorToLLVMMatrixConversionPatterns( + LLVMTypeConverter &converter, OwningRewritePatternList &patterns); + /// Collect a set of patterns to convert from the Vector dialect to LLVM. void populateVectorToLLVMConversionPatterns(LLVMTypeConverter &converter, OwningRewritePatternList &patterns); diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -797,12 +797,12 @@ : LLVM_OneResultOp<"intr.matrix.multiply">, Arguments<( ins LLVM_Type:$lhs, LLVM_Type:$rhs, - I32Attr:$lhs_rows, I32Attr:$lhs_columns, I32Attr:$rhs_rows)> { + I32Attr:$lhs_rows, I32Attr:$lhs_columns, I32Attr:$rhs_columns)> { string llvmBuilder = [{ llvm::MatrixBuilder mb(builder); $res = mb.CreateMatrixMultiply( $lhs, $rhs, $lhs_rows.getZExtValue(), $lhs_columns.getZExtValue(), - $rhs_rows.getZExtValue()); + $rhs_columns.getZExtValue()); }]; let assemblyFormat = "$lhs `,` $rhs attr-dict " "`:` `(` type($lhs) `,` type($rhs) `)` `->` type($res)"; diff --git a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td --- a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td +++ b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td @@ -1323,4 +1323,57 @@ let assemblyFormat = "$source attr-dict `:` type($source)"; } 
+//===----------------------------------------------------------------------===// +// Ops used for supporting progressive lowering and conversion type changes. +//===----------------------------------------------------------------------===// + +/// Vector dialect matrix multiplication op that operates on flattened 1-D +/// MLIR vectors. This is the counterpart of llvm.matrix.multiply in MLIR. +/// This may seem redundant with vector.contract but it serves the purposes of +/// more progressive lowering and localized type conversion on the path: +/// `vector<...x...xf32> -> vector<...xf32> -> !llvm<... x float>`. +def Vector_MatmulOp : Vector_Op<"matrix_multiply", [NoSideEffect, + PredOpTrait<"lhs operand and result have same element type", + TCresVTEtIsSameAsOpBase<0, 0>>, + PredOpTrait<"rhs operand and result have same element type", + TCresVTEtIsSameAsOpBase<0, 1>>]>, + Arguments<( + // TODO(ntv, fhahn): tighten vector element types that make sense. + ins VectorOfRankAndType<[1], + [AnySignlessInteger, AnySignedInteger, AnyFloat]>:$lhs, + VectorOfRankAndType<[1], + [AnySignlessInteger, AnySignedInteger, AnyFloat]>:$rhs, + I32Attr:$lhs_rows, I32Attr:$lhs_columns, I32Attr:$rhs_columns)>, + Results<( + outs VectorOfRankAndType<[1], + [AnySignlessInteger, AnySignedInteger, AnyFloat]>:$res)> +{ + let summary = "Vector matrix multiplication op that operates on flattened 1-D" + " MLIR vectors"; + let description = [{ + This is the counterpart of llvm.matrix.multiply in MLIR. + This may seem redundant with vector.contract but it serves the purposes of + more progressive lowering and localized type conversion on the path: + `vector<...x...xf32> -> vector<...xf32> -> !llvm<... x float>`. 
+ + Example: + + }]; + let builders = [ + OpBuilder<"Builder *builder, OperationState &result, Value lhs, Value rhs, " + "unsigned lhsRows, unsigned lhsColumns, unsigned rhsColumns", + [{ + result.addOperands({lhs, rhs}); + result.addAttribute("lhs_rows", builder->getI32IntegerAttr(lhsRows)); + result.addAttribute("lhs_columns", builder->getI32IntegerAttr(lhsColumns)); + result.addAttribute("rhs_columns", builder->getI32IntegerAttr(rhsColumns)); + result.addTypes(VectorType::get(lhsRows * rhsColumns, + lhs.getType().cast<VectorType>().getElementType())); + }]>, + ]; + let verifier = ?; + let assemblyFormat = "$lhs `,` $rhs attr-dict " + "`:` `(` type($lhs) `,` type($rhs) `)` `->` type($res)"; +} + #endif // VECTOR_OPS diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -485,6 +485,29 @@ class VectorOf allowedTypes> : ShapedContainerType; +// Whether the rank of a vector is from the given +// `allowedRanks` list +class IsVectorOfRankPred allowedRanks> : + And<[IsVectorTypePred, + Or().getRank() + == }] + # allowedlength>)>]>; + +// Any vector where the rank is from the given `allowedRanks` list +class VectorOfRank allowedRanks> : Type< + IsVectorOfRankPred, + " of ranks " # StrJoinInt.result>; + +// Any vector where the rank is from the given `allowedRanks` list and the type +// is from the given `allowedTypes` list +class VectorOfRankAndType allowedRanks, + list allowedTypes> : Type< + And<[VectorOf.predicate, + VectorOfRank.predicate]>, + VectorOf.description # + VectorOfRank.description>; + // Whether the number of elements of a vector is from the given // `allowedLengths` list class IsVectorOfLengthPred allowedLengths> : diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ 
b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -275,6 +275,28 @@ } }; +/// Conversion pattern for vector.matrix_multiply. The op is replaced by an +/// llvm.intr.matrix.multiply that forwards the same three shape attributes. +class VectorMatmulOpConversion : public ConvertToLLVMPattern { +public: + explicit VectorMatmulOpConversion(MLIRContext *context, + LLVMTypeConverter &typeConverter) + : ConvertToLLVMPattern(vector::MatmulOp::getOperationName(), context, + typeConverter) {} + + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const override { + auto matmulOp = cast<vector::MatmulOp>(op); + auto adaptor = vector::MatmulOpOperandAdaptor(operands); + rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>( + op, typeConverter.convertType(matmulOp.res().getType()), adaptor.lhs(), + adaptor.rhs(), matmulOp.lhs_rows(), matmulOp.lhs_columns(), + matmulOp.rhs_columns()); + return matchSuccess(); + } +}; + class VectorReductionOpConversion : public ConvertToLLVMPattern { public: explicit VectorReductionOpConversion(MLIRContext *context, @@ -1141,6 +1163,12 @@ VectorPrintOpConversion>(ctx, converter); } +void mlir::populateVectorToLLVMMatrixConversionPatterns( + LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { + MLIRContext *ctx = converter.getDialect()->getContext(); + patterns.insert<VectorMatmulOpConversion>(ctx, converter); +} + namespace { struct LowerVectorToLLVMPass : public ModulePass<LowerVectorToLLVMPass> { void runOnModule() override; @@ -1160,6 +1188,7 @@ // Convert to the LLVM IR dialect. 
LLVMTypeConverter converter(&getContext()); OwningRewritePatternList patterns; + populateVectorToLLVMMatrixConversionPatterns(converter, patterns); populateVectorToLLVMConversionPatterns(converter, patterns); populateStdToLLVMConversionPatterns(converter, patterns); diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -701,3 +701,15 @@ // CHECK: %[[V:.*]] = "llvm.intr.experimental.vector.reduce.add"(%[[A]]) // CHECK: llvm.return %[[V]] : !llvm.i64 + +// lhs: 4x16 (64 elements), rhs: 16x3 (48 elements), result: 4x3 (12 elements). +func @matrix_ops(%A: vector<64xf64>, %B: vector<48xf64>) -> vector<12xf64> { + %C = vector.matrix_multiply %A, %B + { lhs_rows = 4: i32, lhs_columns = 16: i32 , rhs_columns = 3: i32 } : + (vector<64xf64>, vector<48xf64>) -> vector<12xf64> + return %C: vector<12xf64> +} +// CHECK-LABEL: llvm.func @matrix_ops +// CHECK: llvm.intr.matrix.multiply %{{.*}}, %{{.*}} { +// CHECK-SAME: lhs_columns = 16 : i32, lhs_rows = 4 : i32, rhs_columns = 3 : i32 +// CHECK-SAME: } : (!llvm<"<64 x double>">, !llvm<"<48 x double>">) -> !llvm<"<12 x double>"> diff --git a/mlir/test/Target/llvmir-intrinsics.mlir b/mlir/test/Target/llvmir-intrinsics.mlir --- a/mlir/test/Target/llvmir-intrinsics.mlir +++ b/mlir/test/Target/llvmir-intrinsics.mlir @@ -138,7 +138,7 @@ { // CHECK: call <12 x float> @llvm.matrix.multiply.v12f32.v64f32.v48f32(<64 x float> %0, <48 x float> %1, i32 4, i32 16, i32 3) %C = llvm.intr.matrix.multiply %A, %B - { lhs_rows = 4: i32, lhs_columns = 16: i32 , rhs_rows = 3: i32} : + { lhs_rows = 4: i32, lhs_columns = 16: i32 , rhs_columns = 3: i32} : (!llvm<"<64 x float>">, !llvm<"<48 x float>">) -> !llvm<"<12 x float>"> llvm.return %C: !llvm<"<12 x float>"> }