diff --git a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
--- a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
+++ b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
@@ -183,6 +183,39 @@
   }];
 }
 
+def Vector_ReductionOp :
+  Vector_Op<"reduction", [NoSideEffect,
+     PredOpTrait<"source operand and result have same element type",
+                 TCresVTEtIsSameAsOpBase<0, 0>>]>,
+    Arguments<(ins StrAttr:$kind, AnyVector:$vector)>,
+    Results<(outs AnyType:$dest)> {
+  let summary = "reduction operation";
+  let description = [{
+    Reduces a 1-D vector "horizontally" into a scalar using the given
+    operation (add/mul/min/max for int/fp and and/or/xor for int only;
+    integer min/max are signed). These operations are restricted to 1-D
+    vectors to remain close to the corresponding LLVM intrinsics:
+
+    http://llvm.org/docs/LangRef.html#experimental-vector-reduction-intrinsics
+
+    Examples:
+    ```
+      %1 = vector.reduction "add", %0 : vector<16xf32> into f32
+
+      %3 = vector.reduction "xor", %2 : vector<4xi32> into i32
+    ```
+  }];
+  let verifier = [{ return ::verify(*this); }];
+  let assemblyFormat = [{
+    $kind `,` $vector attr-dict `:` type($vector) `into` type($dest)
+  }];
+  let extraClassDeclaration = [{
+    VectorType getVectorType() {
+      return vector().getType().cast<VectorType>();
+    }
+  }];
+}
+
 def Vector_BroadcastOp :
   Vector_Op<"broadcast", [NoSideEffect,
     PredOpTrait<"source operand and result have same element type",
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -124,6 +124,7 @@
 }
 
 namespace {
+
 class VectorBroadcastOpConversion : public LLVMOpLowering {
 public:
   explicit VectorBroadcastOpConversion(MLIRContext *context,
@@ -272,6 +273,73 @@
   }
 };
 
+class VectorReductionOpConversion : public LLVMOpLowering {
+public:
+  explicit VectorReductionOpConversion(MLIRContext *context,
+                                       LLVMTypeConverter &typeConverter)
+      : LLVMOpLowering(vector::ReductionOp::getOperationName(), context,
+                       typeConverter) {}
+
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto reductionOp = cast<vector::ReductionOp>(op);
+    auto kind = reductionOp.kind();
+    Type eltType = reductionOp.dest().getType();
+    Type llvmType = lowering.convertType(eltType);
+    if (eltType.isInteger(32) || eltType.isInteger(64)) {
+      // Integer reductions: add/mul/min/max/and/or/xor (min/max are signed).
+      if (kind == "add")
+        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_add>(
+            op, llvmType, operands[0]);
+      else if (kind == "mul")
+        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_mul>(
+            op, llvmType, operands[0]);
+      else if (kind == "min")
+        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_smin>(
+            op, llvmType, operands[0]);
+      else if (kind == "max")
+        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_smax>(
+            op, llvmType, operands[0]);
+      else if (kind == "and")
+        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_and>(
+            op, llvmType, operands[0]);
+      else if (kind == "or")
+        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_or>(
+            op, llvmType, operands[0]);
+      else if (kind == "xor")
+        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_xor>(
+            op, llvmType, operands[0]);
+      else
+        return matchFailure();
+      return matchSuccess();
+
+    } else if (eltType.isF32() || eltType.isF64()) {
+      // FP reductions: add/mul/min/max; the v2 add/mul take a start value.
+      if (kind == "add") {
+        Value zero = rewriter.create<LLVM::ConstantOp>(
+            op->getLoc(), llvmType, rewriter.getZeroAttr(eltType));
+        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_v2_fadd>(
+            op, llvmType, zero, operands[0]);
+      } else if (kind == "mul") {
+        Value one = rewriter.create<LLVM::ConstantOp>(
+            op->getLoc(), llvmType, rewriter.getFloatAttr(eltType, 1.0));
+        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_v2_fmul>(
+            op, llvmType, one, operands[0]);
+      } else if (kind == "min")
+        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_fmin>(
+            op, llvmType, operands[0]);
+      else if (kind == "max")
+        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_fmax>(
+            op, llvmType, operands[0]);
+      else
+        return matchFailure();
+      return matchSuccess();
+    }
+    return matchFailure();
+  }
+};
+
 class
VectorShuffleOpConversion : public LLVMOpLowering { public: explicit VectorShuffleOpConversion(MLIRContext *context, @@ -1056,12 +1124,12 @@ VectorInsertStridedSliceOpDifferentRankRewritePattern, VectorInsertStridedSliceOpSameRankRewritePattern, VectorStridedSliceOpConversion>(ctx); - patterns.insert( - ctx, converter); + patterns.insert(ctx, converter); } namespace { diff --git a/mlir/lib/Dialect/VectorOps/VectorOps.cpp b/mlir/lib/Dialect/VectorOps/VectorOps.cpp --- a/mlir/lib/Dialect/VectorOps/VectorOps.cpp +++ b/mlir/lib/Dialect/VectorOps/VectorOps.cpp @@ -61,6 +61,33 @@ } //===----------------------------------------------------------------------===// +// ReductionOp +//===----------------------------------------------------------------------===// + +static LogicalResult verify(ReductionOp op) { + // Only 1-D vectors are supported, mirroring the LLVM reduction intrinsics. + int64_t rank = op.getVectorType().getRank(); + if (rank != 1) + return op.emitOpError("unsupported reduction rank: ") << rank; + + // Verify the kind is known and legal for the element type.
+ auto kind = op.kind(); + Type eltType = op.dest().getType(); + if (kind == "add" || kind == "mul" || kind == "min" || kind == "max") { + if (eltType.isF32() || eltType.isF64() || eltType.isInteger(32) || + eltType.isInteger(64)) + return success(); + return op.emitOpError("unsupported reduction type"); + } + if (kind == "and" || kind == "or" || kind == "xor") { + if (eltType.isInteger(32) || eltType.isInteger(64)) + return success(); + return op.emitOpError("unsupported reduction type"); + } + return op.emitOpError("unknown reduction kind: ") << kind; +} + +//===----------------------------------------------------------------------===// // ContractionOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -645,7 +645,7 @@ // CHECK: "llvm.intr.fma"(%[[A]], %[[A]], %[[A]]) : // CHECK-SAME: (!llvm<"<8 x float>">, !llvm<"<8 x float>">, !llvm<"<8 x float>">) -> !llvm<"<8 x float>"> %0 = vector.fma %a, %a, %a : vector<8xf32> - + // CHECK: %[[b00:.*]] = llvm.extractvalue %[[B]][0] : !llvm<"[2 x <4 x float>]"> // CHECK: %[[b01:.*]] = llvm.extractvalue %[[B]][0] : !llvm<"[2 x <4 x float>]"> // CHECK: %[[b02:.*]] = llvm.extractvalue %[[B]][0] : !llvm<"[2 x <4 x float>]"> @@ -659,7 +659,45 @@ // CHECK-SAME: (!llvm<"<4 x float>">, !llvm<"<4 x float>">, !llvm<"<4 x float>">) -> !llvm<"<4 x float>"> // CHECK: llvm.insertvalue %[[B1]], {{.*}}[1] : !llvm<"[2 x <4 x float>]"> %1 = vector.fma %b, %b, %b : vector<2x4xf32> - + return %0, %1: vector<8xf32>, vector<2x4xf32> } - + +func @reduce_f32(%arg0: vector<16xf32>) -> f32 { + %0 = vector.reduction "add", %arg0 : vector<16xf32> into f32 + return %0 : f32 +} +// CHECK-LABEL: llvm.func @reduce_f32 +// CHECK-SAME: %[[A:.*]]: !llvm<"<16 x float>"> +// CHECK:
%[[C:.*]] = llvm.mlir.constant(0.000000e+00 : f32) : !llvm.float +// CHECK: %[[V:.*]] = "llvm.intr.experimental.vector.reduce.v2.fadd"(%[[C]], %[[A]]) +// CHECK: llvm.return %[[V]] : !llvm.float + +func @reduce_f64(%arg0: vector<16xf64>) -> f64 { + %0 = vector.reduction "add", %arg0 : vector<16xf64> into f64 + return %0 : f64 +} +// CHECK-LABEL: llvm.func @reduce_f64 +// CHECK-SAME: %[[A:.*]]: !llvm<"<16 x double>"> +// CHECK: %[[C:.*]] = llvm.mlir.constant(0.000000e+00 : f64) : !llvm.double +// CHECK: %[[V:.*]] = "llvm.intr.experimental.vector.reduce.v2.fadd"(%[[C]], %[[A]]) +// CHECK: llvm.return %[[V]] : !llvm.double + +func @reduce_i32(%arg0: vector<16xi32>) -> i32 { + %0 = vector.reduction "add", %arg0 : vector<16xi32> into i32 + return %0 : i32 +} +// CHECK-LABEL: llvm.func @reduce_i32 +// CHECK-SAME: %[[A:.*]]: !llvm<"<16 x i32>"> +// CHECK: %[[V:.*]] = "llvm.intr.experimental.vector.reduce.add"(%[[A]]) +// CHECK: llvm.return %[[V]] : !llvm.i32 + +func @reduce_i64(%arg0: vector<16xi64>) -> i64 { + %0 = vector.reduction "add", %arg0 : vector<16xi64> into i64 + return %0 : i64 +} +// CHECK-LABEL: llvm.func @reduce_i64 +// CHECK-SAME: %[[A:.*]]: !llvm<"<16 x i64>"> +// CHECK: %[[V:.*]] = "llvm.intr.experimental.vector.reduce.add"(%[[A]]) +// CHECK: llvm.return %[[V]] : !llvm.i64 + diff --git a/mlir/test/Dialect/VectorOps/invalid.mlir b/mlir/test/Dialect/VectorOps/invalid.mlir --- a/mlir/test/Dialect/VectorOps/invalid.mlir +++ b/mlir/test/Dialect/VectorOps/invalid.mlir @@ -990,3 +990,31 @@ %1 = vector.shape_cast %arg1 : tuple, vector<3x4x2xf32>> to tuple> } + +// ----- + +func @reduce_unknown_kind(%arg0: vector<16xf32>) -> f32 { + // expected-error@+1 {{'vector.reduction' op unknown reduction kind: joho}} + %0 = vector.reduction "joho", %arg0 : vector<16xf32> into f32 +} + +// ----- + +func @reduce_elt_type_mismatch(%arg0: vector<16xf32>) -> i32 { + // expected-error@+1 {{'vector.reduction' op failed to verify that source operand and result have same element
type}} + %0 = vector.reduction "add", %arg0 : vector<16xf32> into i32 +} + +// ----- + +func @reduce_unsupported_type(%arg0: vector<16xf32>) -> f32 { + // expected-error@+1 {{'vector.reduction' op unsupported reduction type}} + %0 = vector.reduction "xor", %arg0 : vector<16xf32> into f32 +} + +// ----- + +func @reduce_unsupported_rank(%arg0: vector<4x16xf32>) -> f32 { + // expected-error@+1 {{'vector.reduction' op unsupported reduction rank: 2}} + %0 = vector.reduction "add", %arg0 : vector<4x16xf32> into f32 +} diff --git a/mlir/test/Dialect/VectorOps/ops.mlir b/mlir/test/Dialect/VectorOps/ops.mlir --- a/mlir/test/Dialect/VectorOps/ops.mlir +++ b/mlir/test/Dialect/VectorOps/ops.mlir @@ -277,3 +277,37 @@ vector.fma %b, %b, %b : vector<8x4xf32> return } + +// CHECK-LABEL: reduce_fp +func @reduce_fp(%arg0: vector<16xf32>) -> f32 { + // CHECK: vector.reduction "add", %{{.*}} : vector<16xf32> into f32 + vector.reduction "add", %arg0 : vector<16xf32> into f32 + // CHECK: vector.reduction "mul", %{{.*}} : vector<16xf32> into f32 + vector.reduction "mul", %arg0 : vector<16xf32> into f32 + // CHECK: vector.reduction "min", %{{.*}} : vector<16xf32> into f32 + vector.reduction "min", %arg0 : vector<16xf32> into f32 + // CHECK: %[[X:.*]] = vector.reduction "max", %{{.*}} : vector<16xf32> into f32 + %0 = vector.reduction "max", %arg0 : vector<16xf32> into f32 + // CHECK: return %[[X]] : f32 + return %0 : f32 +} + +// CHECK-LABEL: reduce_int +func @reduce_int(%arg0: vector<16xi32>) -> i32 { + // CHECK: vector.reduction "add", %{{.*}} : vector<16xi32> into i32 + vector.reduction "add", %arg0 : vector<16xi32> into i32 + // CHECK: vector.reduction "mul", %{{.*}} : vector<16xi32> into i32 + vector.reduction "mul", %arg0 : vector<16xi32> into i32 + // CHECK: vector.reduction "min", %{{.*}} : vector<16xi32> into i32 + vector.reduction "min", %arg0 : vector<16xi32> into i32 + // CHECK: vector.reduction "max", %{{.*}} : vector<16xi32> into i32 + vector.reduction "max", %arg0 :
vector<16xi32> into i32 + // CHECK: vector.reduction "and", %{{.*}} : vector<16xi32> into i32 + vector.reduction "and", %arg0 : vector<16xi32> into i32 + // CHECK: vector.reduction "or", %{{.*}} : vector<16xi32> into i32 + vector.reduction "or", %arg0 : vector<16xi32> into i32 + // CHECK: %[[X:.*]] = vector.reduction "xor", %{{.*}} : vector<16xi32> into i32 + %0 = vector.reduction "xor", %arg0 : vector<16xi32> into i32 + // CHECK: return %[[X]] : i32 + return %0 : i32 +}