diff --git a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h --- a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h +++ b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h @@ -15,6 +15,12 @@ class ModuleOp; template <typename T> class OpPassBase; +/// Collect a set of patterns to convert vector.matrix_multiply ops to LLVM +/// matrix intrinsics (llvm.intr.matrix.multiply). To lower to assembly, the +/// LLVM flag -lower-matrix-intrinsics will be needed when invoking LLVM. +void populateVectorToLLVMMatrixConversionPatterns( + LLVMTypeConverter &converter, OwningRewritePatternList &patterns); + /// Collect a set of patterns to convert from the Vector dialect to LLVM. void populateVectorToLLVMConversionPatterns(LLVMTypeConverter &converter, OwningRewritePatternList &patterns); diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -836,12 +836,12 @@ : LLVM_OneResultOp<"intr.matrix.multiply">, Arguments<( ins LLVM_Type:$lhs, LLVM_Type:$rhs, - I32Attr:$lhs_rows, I32Attr:$lhs_columns, I32Attr:$rhs_rows)> { + I32Attr:$lhs_rows, I32Attr:$lhs_columns, I32Attr:$rhs_columns)> { string llvmBuilder = [{ llvm::MatrixBuilder mb(builder); $res = mb.CreateMatrixMultiply( $lhs, $rhs, $lhs_rows.getZExtValue(), $lhs_columns.getZExtValue(), - $rhs_rows.getZExtValue()); + $rhs_columns.getZExtValue()); }]; let assemblyFormat = "$lhs `,` $rhs attr-dict " "`:` `(` type($lhs) `,` type($rhs) `)` `->` type($res)"; diff --git a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td --- a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td +++ b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td @@ -1336,4 +1336,65 @@ let assemblyFormat = "$source attr-dict `:` type($source)"; } 
+//===----------------------------------------------------------------------===// +// Ops used for supporting progressive lowering and conversion type changes. +//===----------------------------------------------------------------------===// + +/// Vector dialect matrix multiplication op that operates on flattened 1-D +/// MLIR vectors. This is the counterpart of llvm.matrix.multiply in MLIR. +/// This may seem redundant with vector.contract but it serves the purposes of +/// more progressive lowering and localized type conversion on the path: +/// `vector<...x...xf32> -> vector<...xf32> -> !llvm<... x float>`. +def Vector_MatmulOp : Vector_Op<"matrix_multiply", [NoSideEffect, + PredOpTrait<"lhs operand and result have same element type", + TCresVTEtIsSameAsOpBase<0, 0>>, + PredOpTrait<"rhs operand and result have same element type", + TCresVTEtIsSameAsOpBase<0, 1>>]>, + Arguments<( + // TODO(ntv, fhahn): tighten vector element types that make sense. + ins VectorOfRankAndType<[1], + [AnySignlessInteger, AnySignedInteger, AnyFloat]>:$lhs, + VectorOfRankAndType<[1], + [AnySignlessInteger, AnySignedInteger, AnyFloat]>:$rhs, + I32Attr:$lhs_rows, I32Attr:$lhs_columns, I32Attr:$rhs_columns)>, + Results<( + outs VectorOfRankAndType<[1], + [AnySignlessInteger, AnySignedInteger, AnyFloat]>:$res)> +{ + let summary = "Vector matrix multiplication op that operates on flattened 1-D" + " MLIR vectors"; + let description = [{ + This is the counterpart of llvm.matrix.multiply in MLIR. It serves the + purposes of more progressive lowering and localized type conversion. + + The `vector.matrix_multiply` op treats `lhs` as matrix with <lhs_rows> rows + and <lhs_columns> columns, `rhs` as matrix with <lhs_columns> rows and + <rhs_columns> columns and multiplies them. The result matrix is returned + embedded in the result vector of <lhs_rows> * <rhs_columns> elements. 
+ + Example: + + ``` + %C = vector.matrix_multiply %A, %B + { lhs_rows = 4: i32, lhs_columns = 16: i32 , rhs_columns = 3: i32 } : + (vector<64xf64>, vector<48xf64>) -> vector<12xf64> + ``` + }]; + let builders = [ + OpBuilder<"Builder *builder, OperationState &result, Value lhs, Value rhs, " + "unsigned lhsRows, unsigned lhsColumns, unsigned rhsColumns", + [{ + result.addOperands({lhs, rhs}); + result.addAttribute("lhs_rows", builder->getI32IntegerAttr(lhsRows)); + result.addAttribute("lhs_columns", builder->getI32IntegerAttr(lhsColumns)); + result.addAttribute("rhs_columns", builder->getI32IntegerAttr(rhsColumns)); + result.addTypes(VectorType::get(lhsRows * rhsColumns, + lhs.getType().cast<VectorType>().getElementType())); + }]>, + ]; + let verifier = ?; + let assemblyFormat = "$lhs `,` $rhs attr-dict " + "`:` `(` type($lhs) `,` type($rhs) `)` `->` type($res)"; +} + #endif // VECTOR_OPS diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -275,6 +275,28 @@ } }; +/// Conversion pattern for a vector.matrix_multiply. +/// This is lowered directly to the proper llvm.intr.matrix.multiply. 
+class VectorMatmulOpConversion : public ConvertToLLVMPattern { +public: + explicit VectorMatmulOpConversion(MLIRContext *context, + LLVMTypeConverter &typeConverter) + : ConvertToLLVMPattern(vector::MatmulOp::getOperationName(), context, + typeConverter) {} + + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const override { + auto matmulOp = cast<vector::MatmulOp>(op); + auto adaptor = vector::MatmulOpOperandAdaptor(operands); + rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>( + op, typeConverter.convertType(matmulOp.res().getType()), adaptor.lhs(), + adaptor.rhs(), matmulOp.lhs_rows(), matmulOp.lhs_columns(), + matmulOp.rhs_columns()); + return matchSuccess(); + } +}; + class VectorReductionOpConversion : public ConvertToLLVMPattern { public: explicit VectorReductionOpConversion(MLIRContext *context, @@ -1141,6 +1163,12 @@ VectorPrintOpConversion>(ctx, converter); } +void mlir::populateVectorToLLVMMatrixConversionPatterns( + LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { + MLIRContext *ctx = converter.getDialect()->getContext(); + patterns.insert<VectorMatmulOpConversion>(ctx, converter); +} + namespace { struct LowerVectorToLLVMPass : public ModulePass<LowerVectorToLLVMPass> { void runOnModule() override; @@ -1160,6 +1188,7 @@ // Convert to the LLVM IR dialect. 
LLVMTypeConverter converter(&getContext()); OwningRewritePatternList patterns; + populateVectorToLLVMMatrixConversionPatterns(converter, patterns); populateVectorToLLVMConversionPatterns(converter, patterns); populateStdToLLVMConversionPatterns(converter, patterns); diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -701,3 +701,15 @@ // CHECK: %[[V:.*]] = "llvm.intr.experimental.vector.reduce.add"(%[[A]]) // CHECK: llvm.return %[[V]] : !llvm.i64 + +// (4x16) * (16x3) -> (4x3): lhs_rows = 4, lhs_columns = 16, rhs_columns = 3. +func @matrix_ops(%A: vector<64xf64>, %B: vector<48xf64>) -> vector<12xf64> { + %C = vector.matrix_multiply %A, %B + { lhs_rows = 4: i32, lhs_columns = 16: i32 , rhs_columns = 3: i32 } : + (vector<64xf64>, vector<48xf64>) -> vector<12xf64> + return %C: vector<12xf64> +} +// CHECK-LABEL: llvm.func @matrix_ops +// CHECK: llvm.intr.matrix.multiply %{{.*}}, %{{.*}} { +// CHECK-SAME: lhs_columns = 16 : i32, lhs_rows = 4 : i32, rhs_columns = 3 : i32 +// CHECK-SAME: } : (!llvm<"<64 x double>">, !llvm<"<48 x double>">) -> !llvm<"<12 x double>"> diff --git a/mlir/test/Target/llvmir-intrinsics.mlir b/mlir/test/Target/llvmir-intrinsics.mlir --- a/mlir/test/Target/llvmir-intrinsics.mlir +++ b/mlir/test/Target/llvmir-intrinsics.mlir @@ -136,7 +136,7 @@ %ptr: !llvm<"float*">, %stride: !llvm.i32) { // CHECK: call <12 x float> @llvm.matrix.multiply.v12f32.v64f32.v48f32(<64 x float> %0, <48 x float> %1, i32 4, i32 16, i32 3) %C = llvm.intr.matrix.multiply %A, %B - { lhs_rows = 4: i32, lhs_columns = 16: i32 , rhs_rows = 3: i32} : + { lhs_rows = 4: i32, lhs_columns = 16: i32 , rhs_columns = 3: i32} : (!llvm<"<64 x float>">, !llvm<"<48 x float>">) -> !llvm<"<12 x float>"> // CHECK: call <48 x float> @llvm.matrix.transpose.v48f32(<48 x float> %1, i32 3, i32 16) %D = llvm.intr.matrix.transpose %B { rows = 3: i32, columns = 16: i32} :