diff --git a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td --- a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td +++ b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td @@ -372,6 +372,33 @@ }]; } +def Vector_FMAOp : + Op<Vector_Dialect, "fma", [NoSideEffect, + AllTypesMatch<["lhs", "rhs", "acc", "result"]>]>, + Arguments<(ins AnyVector:$lhs, AnyVector:$rhs, AnyVector:$acc)>, + Results<(outs AnyVector:$result)> { + let summary = "vector fused multiply-add"; + let description = [{ + Multiply-add expression that operates on n-D f32 or f64 vectors and lowers + to the llvm.fmuladd.* intrinsic. + + Example: + + ``` + %3 = vector.fma %0, %1, %2: vector<8x16xf32> + ``` + }]; + // Fully specified by traits. + let verifier = ?; + let assemblyFormat = "$lhs `,` $rhs `,` $acc attr-dict `:` type($lhs)"; + let builders = [OpBuilder< + "Builder *b, OperationState &result, Value lhs, Value rhs, Value acc", + "build(b, result, lhs.getType(), lhs, rhs, acc);">]; + let extraClassDeclaration = [{ + VectorType getVectorType() { return lhs().getType().cast<VectorType>(); } + }]; +} + def Vector_InsertElementOp : Vector_Op<"insertelement", [NoSideEffect, PredOpTrait<"source operand and result have same element type", diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -410,6 +410,27 @@ } }; +class VectorFMAOpConversion : public LLVMOpLowering { +public: + explicit VectorFMAOpConversion(MLIRContext *context, + LLVMTypeConverter &typeConverter) + : LLVMOpLowering(vector::FMAOp::getOperationName(), context, + typeConverter) {} + + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const override { + auto adaptor = vector::FMAOpOperandAdaptor(operands); + vector::FMAOp fmaOp = cast<vector::FMAOp>(op); + VectorType vType = fmaOp.getVectorType(); + if (vType.getRank() != 1) + return 
matchFailure(); + rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(op, adaptor.lhs(), + adaptor.rhs(), adaptor.acc()); + return matchSuccess(); + } +}; + class VectorInsertElementOpConversion : public LLVMOpLowering { public: explicit VectorInsertElementOpConversion(MLIRContext *context, @@ -502,6 +523,34 @@ } }; +// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1. +class VectorFMAOpRewritePattern : public OpRewritePattern<FMAOp> { +public: + using OpRewritePattern<FMAOp>::OpRewritePattern; + + PatternMatchResult matchAndRewrite(FMAOp op, + PatternRewriter &rewriter) const override { + auto vType = op.getVectorType(); + if (vType.getRank() < 2) + return matchFailure(); + + auto loc = op.getLoc(); + auto elemType = vType.getElementType(); + Value zero = rewriter.create<ConstantOp>(loc, elemType, + rewriter.getZeroAttr(elemType)); + Value desc = rewriter.create<SplatOp>(loc, vType, zero); + for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) { + Value extrLHS = rewriter.create<ExtractOp>(loc, op.lhs(), i); + Value extrRHS = rewriter.create<ExtractOp>(loc, op.rhs(), i); + Value extrACC = rewriter.create<ExtractOp>(loc, op.acc(), i); + Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC); + desc = rewriter.create<InsertOp>(loc, fma, desc, i); + } + rewriter.replaceOp(op, desc); + return matchSuccess(); + } +}; + // When ranks are different, InsertStridedSlice needs to extract a properly // ranked vector from the destination vector into which to insert. 
This pattern // only takes care of this part and forwards the rest of the conversion to @@ -955,14 +1004,16 @@ void mlir::populateVectorToLLVMConversionPatterns( LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { MLIRContext *ctx = converter.getDialect()->getContext(); - patterns.insert(ctx); patterns.insert(ctx, converter); + VectorFMAOpConversion, VectorInsertElementOpConversion, + VectorInsertOpConversion, VectorOuterProductOpConversion, + VectorTypeCastOpConversion, VectorPrintOpConversion>( + ctx, converter); } namespace { diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -610,3 +610,19 @@ // CHECK: %[[s7:.*]] = llvm.insertelement %[[s5]], %[[s3]][%[[s6]] : !llvm.i64] : !llvm<"<1 x float>"> // CHECK: %[[s8:.*]] = llvm.insertvalue %[[s7]], %[[s0]][0] : !llvm<"[1 x <1 x float>]"> // CHECK: llvm.return %[[s8]] : !llvm<"[1 x <1 x float>]"> + +// CHECK-LABEL: llvm.func @vector_fma +func @vector_fma(%a: vector<8xf32>, %b: vector<2x4xf32>) + -> (vector<8xf32>, vector<2x4xf32>) +{ + // CHECK: llvm.intr.fmuladd{{.*}}: (!llvm<"<8 x float>">, !llvm<"<8 x float>">, !llvm<"<8 x float>">) -> !llvm<"<8 x float>"> + %0 = vector.fma %a, %a, %a : vector<8xf32> + // CHECK-COUNT-3: llvm.extractvalue {{.*}}[0] : !llvm<"[2 x <4 x float>]"> + // CHECK: llvm.intr.fmuladd{{.*}} : (!llvm<"<4 x float>">, !llvm<"<4 x float>">, !llvm<"<4 x float>">) -> !llvm<"<4 x float>"> + // CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[2 x <4 x float>]"> + // CHECK-COUNT-3: llvm.extractvalue {{.*}}[1] : !llvm<"[2 x <4 x float>]"> + // CHECK: llvm.intr.fmuladd{{.*}} : (!llvm<"<4 x float>">, !llvm<"<4 x float>">, !llvm<"<4 x float>">) -> !llvm<"<4 x float>"> + // CHECK: llvm.insertvalue {{.*}}, {{.*}}[1] : !llvm<"[2 x <4 x float>]"> + %1 = vector.fma %b, %b, %b : vector<2x4xf32> + return 
%0, %1: vector<8xf32>, vector<2x4xf32> +} diff --git a/mlir/test/Dialect/VectorOps/ops.mlir b/mlir/test/Dialect/VectorOps/ops.mlir --- a/mlir/test/Dialect/VectorOps/ops.mlir +++ b/mlir/test/Dialect/VectorOps/ops.mlir @@ -233,3 +233,12 @@ return %1 : vector<2x3x4xf32> } + +// CHECK-LABEL: @vector_fma +func @vector_fma(%a: vector<8xf32>, %b: vector<8x4xf32>) { + // CHECK: vector.fma %{{.*}} : vector<8xf32> + vector.fma %a, %a, %a : vector<8xf32> + // CHECK: vector.fma %{{.*}} : vector<8x4xf32> + vector.fma %b, %b, %b : vector<8x4xf32> + return +}