diff --git a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
--- a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
+++ b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
@@ -388,6 +388,38 @@
   }];
 }
 
+def Vector_FMAOp :
+  Op<Vector_Dialect, "fma", [NoSideEffect,
+                             AllTypesMatch<["lhs", "rhs", "acc", "result"]>]>,
+    Arguments<(ins AnyVector:$lhs, AnyVector:$rhs, AnyVector:$acc)>,
+    Results<(outs AnyVector:$result)> {
+  let summary = "vector fused multiply-add";
+  let description = [{
+    Multiply-add expressions operate on n-D vectors and compute a fused
+    pointwise multiply-and-accumulate: `$result = $lhs * $rhs + $acc`.
+    All operands and result have the same vector type. The semantics
+    of the operation correspond to those of the `llvm.fma`
+    [intrinsic](https://llvm.org/docs/LangRef.html#int-fma). In the
+    particular case of lowering to LLVM, this is guaranteed to lower
+    to the `llvm.fma.*` intrinsic.
+
+    Example:
+
+    ```
+    %3 = vector.fma %0, %1, %2 : vector<8x16xf32>
+    ```
+  }];
+  // Fully specified by traits.
+  let verifier = ?;
+  let assemblyFormat = "$lhs `,` $rhs `,` $acc attr-dict `:` type($lhs)";
+  let builders = [OpBuilder<
+    "Builder *b, OperationState &result, Value lhs, Value rhs, Value acc",
+    "build(b, result, lhs.getType(), lhs, rhs, acc);">];
+  let extraClassDeclaration = [{
+    VectorType getVectorType() { return lhs().getType().cast<VectorType>(); }
+  }];
+}
+
 def Vector_InsertElementOp :
   Vector_Op<"insertelement", [NoSideEffect,
     PredOpTrait<"source operand and result have same element type",
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -410,6 +410,41 @@
   }
 };
 
+/// Conversion pattern that turns a vector.fma on a 1-D vector
+/// into an llvm.intr.fma. This is a trivial 1-1 conversion.
+/// This pattern does not match vectors of rank n >= 2.
+///
+/// Example:
+/// ```
+///  vector.fma %a, %a, %a : vector<8xf32>
+/// ```
+/// is converted to:
+/// ```
+///  llvm.intr.fma %va, %va, %va:
+///    (!llvm<"<8 x float>">, !llvm<"<8 x float>">, !llvm<"<8 x float>">)
+///    -> !llvm<"<8 x float>">
+/// ```
+class VectorFMAOp1DConversion : public LLVMOpLowering {
+public:
+  explicit VectorFMAOp1DConversion(MLIRContext *context,
+                                   LLVMTypeConverter &typeConverter)
+      : LLVMOpLowering(vector::FMAOp::getOperationName(), context,
+                       typeConverter) {}
+
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto adaptor = vector::FMAOpOperandAdaptor(operands);
+    vector::FMAOp fmaOp = cast<vector::FMAOp>(op);
+    VectorType vType = fmaOp.getVectorType();
+    if (vType.getRank() != 1)
+      return matchFailure();
+    rewriter.replaceOpWithNewOp<LLVM::FMAOp>(op, adaptor.lhs(), adaptor.rhs(),
+                                             adaptor.acc());
+    return matchSuccess();
+  }
+};
+
 class VectorInsertElementOpConversion : public LLVMOpLowering {
 public:
   explicit VectorInsertElementOpConversion(MLIRContext *context,
@@ -502,6 +537,54 @@
   }
 };
 
+/// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1.
+///
+/// Example:
+/// ```
+///  %d = vector.fma %a, %b, %c : vector<2x4xf32>
+/// ```
+/// is rewritten into:
+/// ```
+///  %r = splat %f0 : vector<2x4xf32>
+///  %va = vector.extract %a[0] : vector<2x4xf32>
+///  %vb = vector.extract %b[0] : vector<2x4xf32>
+///  %vc = vector.extract %c[0] : vector<2x4xf32>
+///  %vd = vector.fma %va, %vb, %vc : vector<4xf32>
+///  %r2 = vector.insert %vd, %r[0] : vector<4xf32> into vector<2x4xf32>
+///  %va2 = vector.extract %a[1] : vector<2x4xf32>
+///  %vb2 = vector.extract %b[1] : vector<2x4xf32>
+///  %vc2 = vector.extract %c[1] : vector<2x4xf32>
+///  %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32>
+///  %r3 = vector.insert %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32>
+///  // %r3 holds the final value.
+/// ```
+class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
+public:
+  using OpRewritePattern<FMAOp>::OpRewritePattern;
+
+  PatternMatchResult matchAndRewrite(FMAOp op,
+                                     PatternRewriter &rewriter) const override {
+    auto vType = op.getVectorType();
+    if (vType.getRank() < 2)
+      return matchFailure();
+
+    auto loc = op.getLoc();
+    auto elemType = vType.getElementType();
+    Value zero = rewriter.create<ConstantOp>(loc, elemType,
+                                             rewriter.getZeroAttr(elemType));
+    Value desc = rewriter.create<SplatOp>(loc, vType, zero);
+    for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
+      Value extrLHS = rewriter.create<ExtractOp>(loc, op.lhs(), i);
+      Value extrRHS = rewriter.create<ExtractOp>(loc, op.rhs(), i);
+      Value extrACC = rewriter.create<ExtractOp>(loc, op.acc(), i);
+      Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC);
+      desc = rewriter.create<InsertOp>(loc, fma, desc, i);
+    }
+    rewriter.replaceOp(op, desc);
+    return matchSuccess();
+  }
+};
+
 // When ranks are different, InsertStridedSlice needs to extract a properly
 // ranked vector from the destination vector into which to insert. This pattern
 // only takes care of this part and forwards the rest of the conversion to
@@ -969,14 +1052,16 @@
 void mlir::populateVectorToLLVMConversionPatterns(
     LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
   MLIRContext *ctx = converter.getDialect()->getContext();
-  patterns.insert<VectorInsertStridedSliceOpDifferentRankRewritePattern,
+  patterns.insert<VectorFMAOpNDRewritePattern,
+                  VectorInsertStridedSliceOpDifferentRankRewritePattern,
                   VectorInsertStridedSliceOpSameRankRewritePattern,
                   VectorStridedSliceOpConversion>(ctx);
   patterns.insert<VectorBroadcastOpConversion, VectorShuffleOpConversion,
                   VectorExtractElementOpConversion, VectorExtractOpConversion,
-                  VectorInsertElementOpConversion, VectorInsertOpConversion,
-                  VectorOuterProductOpConversion, VectorTypeCastOpConversion,
-                  VectorPrintOpConversion>(ctx, converter);
+                  VectorFMAOp1DConversion, VectorInsertElementOpConversion,
+                  VectorInsertOpConversion, VectorOuterProductOpConversion,
+                  VectorTypeCastOpConversion, VectorPrintOpConversion>(
+      ctx, converter);
 }
 
 namespace {
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -637,3 +637,29 @@
 // CHECK: %[[s7:.*]] = llvm.insertelement %[[s5]], %[[s3]][%[[s6]] : !llvm.i64] : !llvm<"<1 x float>">
 // CHECK: %[[s8:.*]] = llvm.insertvalue %[[s7]], %[[s0]][0] : !llvm<"[1 x <1 x float>]">
 // CHECK: llvm.return %[[s8]] : !llvm<"[1 x <1 x float>]">
+
+// CHECK-LABEL: llvm.func @vector_fma(
+// CHECK-SAME: %[[A:.*]]: !llvm<"<8 x float>">, %[[B:.*]]: !llvm<"[2 x <4 x float>]">)
+// CHECK-SAME: -> !llvm<"{ <8 x float>, [2 x <4 x float>] }"> {
+func @vector_fma(%a: vector<8xf32>, %b: vector<2x4xf32>) -> (vector<8xf32>, vector<2x4xf32>) {
+  // CHECK: "llvm.intr.fma"(%[[A]], %[[A]], %[[A]]) :
+  // CHECK-SAME: (!llvm<"<8 x float>">, !llvm<"<8 x float>">, !llvm<"<8 x float>">) -> !llvm<"<8 x float>">
+  %0 = vector.fma %a, %a, %a : vector<8xf32>
+
+  // CHECK: %[[b00:.*]] = llvm.extractvalue %[[B]][0] : !llvm<"[2 x <4 x float>]">
+  // CHECK: %[[b01:.*]] = llvm.extractvalue %[[B]][0] : !llvm<"[2 x <4 x float>]">
+  // CHECK: %[[b02:.*]] = llvm.extractvalue %[[B]][0] : !llvm<"[2 x <4 x float>]">
+  // CHECK: %[[B0:.*]] = "llvm.intr.fma"(%[[b00]], %[[b01]], %[[b02]]) :
+  // CHECK-SAME: (!llvm<"<4 x float>">, !llvm<"<4 x float>">, !llvm<"<4 x float>">) -> !llvm<"<4 x float>">
+  // CHECK: llvm.insertvalue %[[B0]], {{.*}}[0] : !llvm<"[2 x <4 x float>]">
+  // CHECK: %[[b10:.*]] = llvm.extractvalue %[[B]][1] : !llvm<"[2 x <4 x float>]">
+  // CHECK: %[[b11:.*]] = llvm.extractvalue %[[B]][1] : !llvm<"[2 x <4 x float>]">
+  // CHECK: %[[b12:.*]] = llvm.extractvalue %[[B]][1] : !llvm<"[2 x <4 x float>]">
+  // CHECK: %[[B1:.*]] = "llvm.intr.fma"(%[[b10]], %[[b11]], %[[b12]]) :
+  // CHECK-SAME: (!llvm<"<4 x float>">, !llvm<"<4 x float>">, !llvm<"<4 x float>">) -> !llvm<"<4 x float>">
+  // CHECK: llvm.insertvalue %[[B1]], {{.*}}[1] : !llvm<"[2 x <4 x float>]">
+  %1 = vector.fma %b, %b, %b : vector<2x4xf32>
+
+  return %0, %1 : vector<8xf32>, vector<2x4xf32>
+}
+
diff --git a/mlir/test/Dialect/VectorOps/ops.mlir b/mlir/test/Dialect/VectorOps/ops.mlir
--- a/mlir/test/Dialect/VectorOps/ops.mlir
+++ b/mlir/test/Dialect/VectorOps/ops.mlir
@@ -268,3 +268,12 @@
   return %0, %1 : vector<15x2xf32>,
     tuple<vector<20x2xf32>, vector<12x2xf32>>
 }
+
+// CHECK-LABEL: @vector_fma
+func @vector_fma(%a: vector<8xf32>, %b: vector<8x4xf32>) {
+  // CHECK: vector.fma %{{.*}} : vector<8xf32>
+  vector.fma %a, %a, %a : vector<8xf32>
+  // CHECK: vector.fma %{{.*}} : vector<8x4xf32>
+  vector.fma %b, %b, %b : vector<8x4xf32>
+  return
+}
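
Note (editorial sketch, not part of the patch): the custom builder declared in the TableGen above takes only the three operands and derives the result type from `lhs`, since all types must match. The snippet below is a minimal illustration of how C++ client code could create the op through that builder; the helper name `buildFMA` and the include paths are assumptions based on the file layout in this patch, not code from the change itself.

```cpp
#include "mlir/Dialect/VectorOps/VectorOps.h"
#include "mlir/IR/Builders.h"

using namespace mlir;

// Hypothetical helper (not in the patch): emits `lhs * rhs + acc` as a single
// n-D vector.fma. The custom builder added above infers the result type from
// `lhs`, so all three values must already share the same vector type.
static Value buildFMA(OpBuilder &builder, Location loc, Value lhs, Value rhs,
                      Value acc) {
  return builder.create<vector::FMAOp>(loc, lhs, rhs, acc);
}
```

When the operands have rank >= 2, the patterns registered by `populateVectorToLLVMConversionPatterns` first peel the op one outer dimension at a time with `VectorFMAOpNDRewritePattern` and then map each remaining 1-D `vector.fma` to `llvm.intr.fma` through `VectorFMAOp1DConversion`, which is what the `@vector_fma` test above checks.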