diff --git a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
--- a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
+++ b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
@@ -15,6 +15,11 @@
 class ModuleOp;
 template <typename T> class OpPassBase;
 
+/// Collect a set of patterns to convert from Vector contractions to LLVM Matrix
+/// Intrinsics. This needs to go through memory at the moment.
+void populateVectorToLLVMMatrixConversionPatterns(
+    LLVMTypeConverter &converter, OwningRewritePatternList &patterns);
+
 /// Collect a set of patterns to convert from the Vector dialect to LLVM.
 void populateVectorToLLVMConversionPatterns(LLVMTypeConverter &converter,
                                             OwningRewritePatternList &patterns);
diff --git a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
--- a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
+++ b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
@@ -1323,4 +1323,45 @@
   let assemblyFormat = "$source attr-dict `:` type($source)";
 }
 
+//===----------------------------------------------------------------------===//
+// Ops used for supporting progressive lowering and conversion type changes.
+//===----------------------------------------------------------------------===//
+
+/// Vector dialect matrix multiplication op that operates on flattened 1-D
+/// MLIR vectors. This is the counterpart of llvm.matrix.multiply in MLIR.
+/// This may seem redundant with vector.contract but it serves the purposes of
+/// more progressive lowering and localized type conversion on the path:
+///   `vector<...x...xf32> -> vector<...xf32> -> !llvm<... x float>`.
+def Vector_MatmulOp : Vector_Op<"matrix_multiply", [NoSideEffect,
+        PredOpTrait<"lhs operand and result have same element type",
+                    TCresVTEtIsSameAsOpBase<0, 0>>,
+        PredOpTrait<"rhs operand and result have same element type",
+                    TCresVTEtIsSameAsOpBase<0, 1>>]>,
+      Arguments<(
+        // TODO(ntv, fhahn): tighten vector element types that make sense.
+        ins VectorOfRankAndType<[1],
+              [AnySignlessInteger, AnySignedInteger, AnyFloat]>:$lhs,
+            VectorOfRankAndType<[1],
+              [AnySignlessInteger, AnySignedInteger, AnyFloat]>:$rhs,
+            I32Attr:$lhs_rows, I32Attr:$lhs_columns, I32Attr:$rhs_rows)>,
+      Results<(outs VectorOfRankAndType<[1],
+               [AnySignlessInteger, AnySignedInteger, AnyFloat]>:$res)> {
+  let builders = [
+   OpBuilder<"Builder *builder, OperationState &result, Value lhs, Value rhs, "
+             "unsigned lhsRows, unsigned lhsColumns, unsigned rhsRows", [{
+     result.addOperands({lhs, rhs});
+     result.addAttribute("lhs_rows", builder->getI32IntegerAttr(lhsRows));
+     result.addAttribute("lhs_columns", builder->getI32IntegerAttr(lhsColumns));
+     result.addAttribute("rhs_rows", builder->getI32IntegerAttr(rhsRows));
+     // The flattened result is lhs_rows x rhs_rows (e.g. 4x16 times 16x3
+     // gives 4x3), so it holds lhsRows * rhsRows elements.
+     result.addTypes(VectorType::get(lhsRows * rhsRows,
+       lhs.getType().cast<VectorType>().getElementType()));
+   }]>,
+  ];
+  let verifier = ?;
+  let assemblyFormat = "$lhs `,` $rhs attr-dict "
+    "`:` `(` type($lhs) `,` type($rhs) `)` `->` type($res)";
+}
+
 #endif // VECTOR_OPS
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -275,6 +275,28 @@
   }
 };
 
+/// Conversion pattern for a vector.matrix_multiply.
+/// This is lowered directly to the proper llvm.intr.matrix.multiply.
+class VectorMatmulOpConversion : public ConvertToLLVMPattern {
+public:
+  explicit VectorMatmulOpConversion(MLIRContext *context,
+                                    LLVMTypeConverter &typeConverter)
+      : ConvertToLLVMPattern(vector::MatmulOp::getOperationName(), context,
+                             typeConverter) {}
+
+  PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto matmulOp = cast<vector::MatmulOp>(op);
+    auto adaptor = vector::MatmulOpOperandAdaptor(operands);
+    // Reuse the incoming operands and forward the three shape attributes as-is.
+    rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>(
+        op, typeConverter.convertType(matmulOp.res().getType()), adaptor.lhs(),
+        adaptor.rhs(), matmulOp.lhs_rows(), matmulOp.lhs_columns(),
+        matmulOp.rhs_rows());
+    return matchSuccess();
+  }
+};
+
 class VectorReductionOpConversion : public ConvertToLLVMPattern {
 public:
   explicit VectorReductionOpConversion(MLIRContext *context,
@@ -1141,6 +1163,12 @@
               VectorPrintOpConversion>(ctx, converter);
 }
 
+void mlir::populateVectorToLLVMMatrixConversionPatterns(
+    LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
+  MLIRContext *ctx = converter.getDialect()->getContext();
+  patterns.insert<VectorMatmulOpConversion>(ctx, converter);
+}
+
 namespace {
 struct LowerVectorToLLVMPass : public ModulePass<LowerVectorToLLVMPass> {
   void runOnModule() override;
@@ -1160,6 +1188,7 @@
   // Convert to the LLVM IR dialect.
   LLVMTypeConverter converter(&getContext());
   OwningRewritePatternList patterns;
+  populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
   populateVectorToLLVMConversionPatterns(converter, patterns);
   populateStdToLLVMConversionPatterns(converter, patterns);
 
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -701,3 +701,15 @@
 // CHECK: %[[V:.*]] = "llvm.intr.experimental.vector.reduce.add"(%[[A]])
 // CHECK: llvm.return %[[V]] : !llvm.i64
 
+
+// 4x16 16x3 4x3
+func @matrix_ops(%A: vector<64xf64>, %B: vector<48xf64>) -> vector<12xf64> {
+  %C = vector.matrix_multiply %A, %B
+    { lhs_rows = 4: i32, lhs_columns = 16: i32 , rhs_rows = 3: i32 } :
+    (vector<64xf64>, vector<48xf64>) -> vector<12xf64>
+  return %C: vector<12xf64>
+}
+// CHECK-LABEL: llvm.func @matrix_ops
+// CHECK: llvm.intr.matrix.multiply %{{.*}}, %{{.*}} {
+// CHECK-SAME: lhs_columns = 16 : i32, lhs_rows = 4 : i32, rhs_rows = 3 : i32
+// CHECK-SAME: } : (!llvm<"<64 x double>">, !llvm<"<48 x double>">) -> !llvm<"<12 x double>">