diff --git a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
--- a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
+++ b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
@@ -187,12 +187,15 @@
   Vector_Op<"reduction", [NoSideEffect,
      PredOpTrait<"source operand and result have same element type",
                  TCresVTEtIsSameAsOpBase<0, 0>>]>,
-    Arguments<(ins StrAttr:$kind, AnyVector:$vector)>,
+    Arguments<(ins StrAttr:$kind, AnyVector:$vector, Variadic<AnyType>:$acc)>,
     Results<(outs AnyType:$dest)> {
   let summary = "reduction operation";
   let description = [{
     Reduces an 1-D vector "horizontally" into a scalar using the given
     operation (add/mul/min/max for int/fp and and/or/xor for int only).
+    Some reductions (add/mul for fp) also allow an optional fused
+    accumulator.
+
     Note that these operations are restricted to 1-D vectors to remain
     close to the corresponding LLVM intrinsics:
 
@@ -203,34 +206,9 @@
       %1 = vector.reduction "add", %0 : vector<16xf32> into f32
 
       %3 = vector.reduction "xor", %2 : vector<4xi32> into i32
 
-    ```
-  }];
-  let verifier = [{ return ::verify(*this); }];
-  let assemblyFormat = [{
-    $kind `,` $vector attr-dict `:` type($vector) `into` type($dest)
-  }];
-  let extraClassDeclaration = [{
-    VectorType getVectorType() {
-      return vector().getType().cast<VectorType>();
-    }
-  }];
-}
-// TODO(ajcbik): quick version with "fused" accumulator; next step
-//               will merge Reduction/ReductionV2 into one with
-//               an optional accumulator instead
-def Vector_ReductionV2Op :
-  Vector_Op<"reductionv2", [NoSideEffect]>,
-    Arguments<(ins StrAttr:$kind, VectorOf<[F32, F64]>:$vector, AnyType:$acc)>,
-    Results<(outs AnyType:$dest)> {
-  let summary = "reduction operation";
-  let description = [{
-    As vector.reduction, but with a fused accumulator (add/mul for fp only).
-  }];
-  let verifier = ?;
-  let assemblyFormat = [{
-    $kind `,` $vector `,` $acc attr-dict `:`
-    type($vector) `,` type($acc) `into` type($dest)
+      %4 = vector.reduction "mul", %0, %1 : vector<16xf32> into f32
+    ```
   }];
   let extraClassDeclaration = [{
     VectorType getVectorType() {
@@ -469,7 +447,7 @@
     to the `llvm.fma.*` intrinsic.
 
     Example:
-    
+
     ```
       %3 = vector.fma %0, %1, %2: vector<8x16xf32>
     ```
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -319,15 +319,22 @@
     } else if (eltType.isF32() || eltType.isF64()) {
       // Floating-point reductions: add/mul/min/max
       if (kind == "add") {
-        Value zero = rewriter.create<LLVM::ConstantOp>(
-            op->getLoc(), llvmType, rewriter.getZeroAttr(eltType));
+        // Optional accumulator (or zero).
+        Value acc = operands.size() > 1 ? operands[1]
+                                        : rewriter.create<LLVM::ConstantOp>(
+                                              op->getLoc(), llvmType,
+                                              rewriter.getZeroAttr(eltType));
         rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_v2_fadd>(
-            op, llvmType, zero, operands[0]);
+            op, llvmType, acc, operands[0]);
       } else if (kind == "mul") {
-        Value one = rewriter.create<LLVM::ConstantOp>(
-            op->getLoc(), llvmType, rewriter.getFloatAttr(eltType, 1.0));
+        // Optional accumulator (or one).
+        Value acc = operands.size() > 1
+                        ? operands[1]
+                        : rewriter.create<LLVM::ConstantOp>(
+                              op->getLoc(), llvmType,
+                              rewriter.getFloatAttr(eltType, 1.0));
         rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_v2_fmul>(
-            op, llvmType, one, operands[0]);
+            op, llvmType, acc, operands[0]);
       } else if (kind == "min")
         rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_fmin>(
             op, llvmType, operands[0]);
@@ -342,33 +349,6 @@
   }
 };
 
-// TODO(ajcbik): merge Reduction and ReductionV2
-class VectorReductionV2OpConversion : public ConvertToLLVMPattern {
-public:
-  explicit VectorReductionV2OpConversion(MLIRContext *context,
-                                         LLVMTypeConverter &typeConverter)
-      : ConvertToLLVMPattern(vector::ReductionV2Op::getOperationName(), context,
-                             typeConverter) {}
-  PatternMatchResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const override {
-    auto reductionOp = cast<vector::ReductionV2Op>(op);
-    auto kind = reductionOp.kind();
-    Type eltType = reductionOp.dest().getType();
-    Type llvmType = typeConverter.convertType(eltType);
-    if (kind == "add") {
-      rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_v2_fadd>(
-          op, llvmType, operands[1], operands[0]);
-      return matchSuccess();
-    } else if (kind == "mul") {
-      rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_v2_fmul>(
-          op, llvmType, operands[1], operands[0]);
-      return matchSuccess();
-    }
-    return matchFailure();
-  }
-};
-
 class VectorShuffleOpConversion : public ConvertToLLVMPattern {
 public:
   explicit VectorShuffleOpConversion(MLIRContext *context,
@@ -1154,12 +1134,11 @@
                   VectorInsertStridedSliceOpSameRankRewritePattern,
                   VectorStridedSliceOpConversion>(ctx);
   patterns.insert<VectorBroadcastOpConversion, VectorReductionOpConversion,
-                  VectorReductionV2OpConversion, VectorShuffleOpConversion,
-                  VectorExtractElementOpConversion, VectorExtractOpConversion,
-                  VectorFMAOp1DConversion, VectorInsertElementOpConversion,
-                  VectorInsertOpConversion, VectorOuterProductOpConversion,
-                  VectorTypeCastOpConversion, VectorPrintOpConversion>(
-      ctx, converter);
+                  VectorShuffleOpConversion, VectorExtractElementOpConversion,
+                  VectorExtractOpConversion, VectorFMAOp1DConversion,
+                  VectorInsertElementOpConversion, VectorInsertOpConversion,
+                  VectorOuterProductOpConversion, VectorTypeCastOpConversion,
+                  VectorPrintOpConversion>(ctx, converter);
 }
 
 namespace {
diff --git a/mlir/lib/Dialect/VectorOps/VectorOps.cpp b/mlir/lib/Dialect/VectorOps/VectorOps.cpp
--- a/mlir/lib/Dialect/VectorOps/VectorOps.cpp
+++ b/mlir/lib/Dialect/VectorOps/VectorOps.cpp
@@ -74,17 +74,54 @@
   auto kind = op.kind();
   Type eltType = op.dest().getType();
   if (kind == "add" || kind == "mul" || kind == "min" || kind == "max") {
-    if (eltType.isF32() || eltType.isF64() || eltType.isSignlessInteger(32) ||
-        eltType.isSignlessInteger(64))
-      return success();
-    return op.emitOpError("unsupported reduction type");
+    if (!eltType.isF32() && !eltType.isF64() &&
+        !eltType.isSignlessInteger(32) && !eltType.isSignlessInteger(64))
+      return op.emitOpError("unsupported reduction type");
+  } else if (kind == "and" || kind == "or" || kind == "xor") {
+    if (!eltType.isSignlessInteger(32) && !eltType.isSignlessInteger(64))
+      return op.emitOpError("unsupported reduction type");
+  } else {
+    return op.emitOpError("unknown reduction kind: ") << kind;
   }
-  if (kind == "and" || kind == "or" || kind == "xor") {
-    if (eltType.isSignlessInteger(32) || eltType.isSignlessInteger(64))
-      return success();
-    return op.emitOpError("unsupported reduction type");
+
+  // Verify optional accumulator.
+  if (!op.acc().empty()) {
+    if (kind != "add" && kind != "mul")
+      return op.emitOpError("no accumulator for reduction kind: ") << kind;
+    if (!eltType.isF32() && !eltType.isF64())
+      return op.emitOpError("no accumulator for type: ") << eltType;
   }
-  return op.emitOpError("unknown reduction kind: ") << kind;
+
+  return success();
+}
+
+static ParseResult parseReductionOp(OpAsmParser &parser,
+                                    OperationState &result) {
+  SmallVector<OpAsmParser::OperandType, 2> operandsInfo;
+  Type redType;
+  Type resType;
+  Attribute attr;
+  if (parser.parseAttribute(attr, "kind", result.attributes) ||
+      parser.parseComma() || parser.parseOperandList(operandsInfo) ||
+      parser.parseColonType(redType) ||
+      parser.parseKeywordType("into", resType) ||
+      (operandsInfo.size() > 0 &&
+       parser.resolveOperand(operandsInfo[0], redType, result.operands)) ||
+      (operandsInfo.size() > 1 &&
+       parser.resolveOperand(operandsInfo[1], resType, result.operands)) ||
+      parser.addTypeToList(resType, result.types))
+    return failure();
+  if (operandsInfo.size() < 1 || operandsInfo.size() > 2)
+    return parser.emitError(parser.getNameLoc(),
+                            "unsupported number of operands");
+  return success();
+}
+
+static void print(OpAsmPrinter &p, ReductionOp op) {
+  p << op.getOperationName() << " \"" << op.kind() << "\", " << op.vector();
+  if (!op.acc().empty())
+    p << ", " << op.acc();
+  p << " : " << op.vector().getType() << " into " << op.dest().getType();
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp b/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
--- a/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
@@ -864,7 +864,7 @@
   }
 };
 
-/// Progressive lowering of ConstractionOp.
+/// Progressive lowering of ContractionOp.
 /// One:
 ///   %x = vector.contract with at least one free/batch dimension
 /// is replaced by:
@@ -1017,8 +1017,8 @@
       Value zero = zeroVector(loc, lhsType, rewriter);
       Value fma = rewriter.create<vector::FMAOp>(loc, op.lhs(), op.rhs(), zero);
       StringAttr kind = rewriter.getStringAttr("add");
-      return rewriter.create<vector::ReductionV2Op>(loc, resType, kind, fma,
-                                                    op.acc());
+      return rewriter.create<vector::ReductionOp>(loc, resType, kind, fma,
+                                                  op.acc());
     }
     // Construct new iterator types and affine map array attribute.
SmallVector lowIndexingMaps; @@ -1067,9 +1067,8 @@ SmallVector results; for (auto it : llvm::enumerate(iteratorTypes)) { int64_t idx = it.index(); - if (idx == index) { + if (idx == index) continue; - } results.push_back(it.value()); } return results; diff --git a/mlir/test/Dialect/VectorOps/invalid.mlir b/mlir/test/Dialect/VectorOps/invalid.mlir --- a/mlir/test/Dialect/VectorOps/invalid.mlir +++ b/mlir/test/Dialect/VectorOps/invalid.mlir @@ -1007,6 +1007,34 @@ // ----- +func @reduce_unsupported_attr(%arg0: vector<16xf32>) -> i32 { + // expected-error@+1 {{'vector.reduction' op attribute 'kind' failed to satisfy constraint: string attribute}} + %0 = vector.reduction 1234, %arg0 : vector<16xf32> into i32 +} + +// ----- + +func @reduce_unsupported_third_argument(%arg0: vector<16xf32>, %arg1: f32) -> f32 { + // expected-error@+1 {{'vector.reduction' unsupported number of operands}} + %0 = vector.reduction "add", %arg0, %arg1, %arg1 : vector<16xf32> into f32 +} + +// ----- + +func @reduce_unsupported_accumulator_kind(%arg0: vector<16xf32>, %arg1: f32) -> f32 { + // expected-error@+1 {{'vector.reduction' op no accumulator for reduction kind: min}} + %0 = vector.reduction "min", %arg0, %arg1 : vector<16xf32> into f32 +} + +// ----- + +func @reduce_unsupported_accumulator_type(%arg0: vector<16xi32>, %arg1: i32) -> i32 { + // expected-error@+1 {{'vector.reduction' op no accumulator for type: 'i32'}} + %0 = vector.reduction "add", %arg0, %arg1 : vector<16xi32> into i32 +} + +// ----- + func @reduce_unsupported_type(%arg0: vector<16xf32>) -> f32 { // expected-error@+1 {{'vector.reduction' op unsupported reduction type}} %0 = vector.reduction "xor", %arg0 : vector<16xf32> into f32 diff --git a/mlir/test/Dialect/VectorOps/ops.mlir b/mlir/test/Dialect/VectorOps/ops.mlir --- a/mlir/test/Dialect/VectorOps/ops.mlir +++ b/mlir/test/Dialect/VectorOps/ops.mlir @@ -279,11 +279,15 @@ } // CHECK-LABEL: reduce_fp -func @reduce_fp(%arg0: vector<16xf32>) -> f32 { +func @reduce_fp(%arg0: 
vector<16xf32>, %arg1: f32) -> f32 { // CHECK: vector.reduction "add", %{{.*}} : vector<16xf32> into f32 vector.reduction "add", %arg0 : vector<16xf32> into f32 + // CHECK: vector.reduction "add", %{{.*}}, %{{.*}} : vector<16xf32> into f32 + vector.reduction "add", %arg0, %arg1 : vector<16xf32> into f32 // CHECK: vector.reduction "mul", %{{.*}} : vector<16xf32> into f32 vector.reduction "mul", %arg0 : vector<16xf32> into f32 + // CHECK: vector.reduction "mul", %{{.*}}, %{{.*}} : vector<16xf32> into f32 + vector.reduction "mul", %arg0, %arg1 : vector<16xf32> into f32 // CHECK: vector.reduction "min", %{{.*}} : vector<16xf32> into f32 vector.reduction "min", %arg0 : vector<16xf32> into f32 // CHECK: %[[X:.*]] = vector.reduction "max", %{{.*}} : vector<16xf32> into f32 diff --git a/mlir/test/Dialect/VectorOps/vector-contract-transforms.mlir b/mlir/test/Dialect/VectorOps/vector-contract-transforms.mlir --- a/mlir/test/Dialect/VectorOps/vector-contract-transforms.mlir +++ b/mlir/test/Dialect/VectorOps/vector-contract-transforms.mlir @@ -16,7 +16,7 @@ // CHECK-SAME: %[[C:.*2]]: f32 // CHECK: %[[Z:.*]] = constant dense<0.000000e+00> : vector<4xf32> // CHECK: %[[F:.*]] = vector.fma %[[A]], %[[B]], %[[Z]] : vector<4xf32> -// CHECK: %[[R:.*]] = vector.reductionv2 "add", %[[F]], %[[C]] +// CHECK: %[[R:.*]] = vector.reduction "add", %[[F]], %[[C]] : vector<4xf32> into f32 // CHECK: return %[[R]] : f32 func @extract_contract1(%arg0: vector<4xf32>, %arg1: vector<4xf32>, %arg2: f32) -> f32 { @@ -44,12 +44,12 @@ // CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2x3xf32> // CHECK: %[[T1:.*]] = vector.extract %[[C]][0] : vector<2xf32> // CHECK: %[[T2:.*]] = vector.fma %[[T0]], %[[B]], %[[Z]] : vector<3xf32> -// CHECK: %[[T3:.*]] = vector.reductionv2 "add", %[[T2]], %[[T1]] : vector<3xf32>, f32 into f32 +// CHECK: %[[T3:.*]] = vector.reduction "add", %[[T2]], %[[T1]] : vector<3xf32> into f32 // CHECK: %[[T4:.*]] = vector.insert %[[T3]], %[[R]] [0] : f32 into vector<2xf32> // 
CHECK: %[[T5:.*]] = vector.extract %[[A]][1] : vector<2x3xf32> // CHECK: %[[T6:.*]] = vector.extract %[[C]][1] : vector<2xf32> // CHECK: %[[T7:.*]] = vector.fma %[[T5]], %[[B]], %[[Z]] : vector<3xf32> -// CHECK: %[[T8:.*]] = vector.reductionv2 "add", %[[T7]], %[[T6]] : vector<3xf32>, f32 into f32 +// CHECK: %[[T8:.*]] = vector.reduction "add", %[[T7]], %[[T6]] : vector<3xf32> into f32 // CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T4]] [1] : f32 into vector<2xf32> // CHECK: return %[[T9]] : vector<2xf32> @@ -80,12 +80,12 @@ // CHECK: %[[T0:.*]] = vector.extract %[[B]][0] : vector<2x3xf32> // CHECK: %[[T1:.*]] = vector.extract %[[C]][0] : vector<2xf32> // CHECK: %[[T2:.*]] = vector.fma %[[A]], %[[T0]], %[[Z]] : vector<3xf32> -// CHECK: %[[T3:.*]] = vector.reductionv2 "add", %[[T2]], %[[T1]] : vector<3xf32>, f32 into f32 +// CHECK: %[[T3:.*]] = vector.reduction "add", %[[T2]], %[[T1]] : vector<3xf32> into f32 // CHECK: %[[T4:.*]] = vector.insert %[[T3]], %[[R]] [0] : f32 into vector<2xf32> // CHECK: %[[T5:.*]] = vector.extract %[[B]][1] : vector<2x3xf32> // CHECK: %[[T6:.*]] = vector.extract %[[C]][1] : vector<2xf32> // CHECK: %[[T7:.*]] = vector.fma %[[A]], %[[T5]], %[[Z]] : vector<3xf32> -// CHECK: %[[T8:.*]] = vector.reductionv2 "add", %[[T7]], %[[T6]] : vector<3xf32>, f32 into f32 +// CHECK: %[[T8:.*]] = vector.reduction "add", %[[T7]], %[[T6]] : vector<3xf32> into f32 // CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T4]] [1] : f32 into vector<2xf32> // CHECK: return %[[T9]] : vector<2xf32> @@ -123,7 +123,7 @@ // CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T4]] [1] : f32 into vector<2xf32> // CHECK: %[[T8:.*]] = vector.extract %[[T1]][0] : vector<2xf32> // CHECK: %[[T9:.*]] = vector.fma %[[T0]], %[[T7]], %[[Z]] : vector<2xf32> -// CHECK: %[[T10:.*]] = vector.reductionv2 "add", %[[T9]], %[[T8]] : vector<2xf32>, f32 into f32 +// CHECK: %[[T10:.*]] = vector.reduction "add", %[[T9]], %[[T8]] : vector<2xf32> into f32 // CHECK: %[[T11:.*]] = vector.insert %[[T10]], 
%[[Z]] [0] : f32 into vector<2xf32> // CHECK: %[[T12:.*]] = vector.extract %[[B]][0] : vector<2x2xf32> // CHECK: %[[T13:.*]] = vector.extract %[[T12]][1] : vector<2xf32> @@ -133,7 +133,7 @@ // CHECK: %[[T17:.*]] = vector.insert %[[T16]], %[[T14]] [1] : f32 into vector<2xf32> // CHECK: %[[T18:.*]] = vector.extract %[[T1]][1] : vector<2xf32> // CHECK: %[[T19:.*]] = vector.fma %[[T0]], %[[T17]], %[[Z]] : vector<2xf32> -// CHECK: %[[T20:.*]] = vector.reductionv2 "add", %[[T19]], %[[T18]] : vector<2xf32>, f32 into f32 +// CHECK: %[[T20:.*]] = vector.reduction "add", %[[T19]], %[[T18]] : vector<2xf32> into f32 // CHECK: %[[T21:.*]] = vector.insert %[[T20]], %[[T11]] [1] : f32 into vector<2xf32> // CHECK: %[[T22:.*]] = vector.insert %[[T21]], %[[R]] [0] : vector<2xf32> into vector<2x2xf32> // CHECK: %[[T23:.*]] = vector.extract %[[A]][1] : vector<2x2xf32> @@ -146,7 +146,7 @@ // CHECK: %[[T30:.*]] = vector.insert %[[T29]], %[[T27]] [1] : f32 into vector<2xf32> // CHECK: %[[T31:.*]] = vector.extract %[[T24]][0] : vector<2xf32> // CHECK: %[[T32:.*]] = vector.fma %[[T23]], %[[T30]], %[[Z]] : vector<2xf32> -// CHECK: %[[T33:.*]] = vector.reductionv2 "add", %[[T32]], %[[T31]] : vector<2xf32>, f32 into f32 +// CHECK: %[[T33:.*]] = vector.reduction "add", %[[T32]], %[[T31]] : vector<2xf32> into f32 // CHECK: %[[T34:.*]] = vector.insert %[[T33]], %[[Z]] [0] : f32 into vector<2xf32> // CHECK: %[[T35:.*]] = vector.extract %[[B]][0] : vector<2x2xf32> // CHECK: %[[T36:.*]] = vector.extract %[[T35]][1] : vector<2xf32> @@ -156,7 +156,7 @@ // CHECK: %[[T40:.*]] = vector.insert %[[T39]], %[[T37]] [1] : f32 into vector<2xf32> // CHECK: %[[T41:.*]] = vector.extract %[[T24]][1] : vector<2xf32> // CHECK: %[[T42:.*]] = vector.fma %[[T23]], %[[T40]], %[[Z]] : vector<2xf32> -// CHECK: %[[T43:.*]] = vector.reductionv2 "add", %[[T42]], %[[T41]] : vector<2xf32>, f32 into f32 +// CHECK: %[[T43:.*]] = vector.reduction "add", %[[T42]], %[[T41]] : vector<2xf32> into f32 // CHECK: %[[T44:.*]] = 
vector.insert %[[T43]], %[[T34]] [1] : f32 into vector<2xf32> // CHECK: %[[T45:.*]] = vector.insert %[[T44]], %[[T22]] [1] : vector<2xf32> into vector<2x2xf32> // CHECK: return %[[T45]] : vector<2x2xf32> @@ -187,11 +187,11 @@ // CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2x3xf32> // CHECK: %[[T1:.*]] = vector.extract %[[B]][0] : vector<2x3xf32> // CHECK: %[[T2:.*]] = vector.fma %[[T0]], %[[T1]], %[[Z]] : vector<3xf32> -// CHECK: %[[T3:.*]] = vector.reductionv2 "add", %[[T2]], %[[C]] : vector<3xf32>, f32 into f32 +// CHECK: %[[T3:.*]] = vector.reduction "add", %[[T2]], %[[C]] : vector<3xf32> into f32 // CHECK: %[[T4:.*]] = vector.extract %[[A]][1] : vector<2x3xf32> // CHECK: %[[T5:.*]] = vector.extract %[[B]][1] : vector<2x3xf32> // CHECK: %[[T6:.*]] = vector.fma %[[T4]], %[[T5]], %[[Z]] : vector<3xf32> -// CHECK: %[[T7:.*]] = vector.reductionv2 "add", %[[T6]], %[[T3]] : vector<3xf32>, f32 into f32 +// CHECK: %[[T7:.*]] = vector.reduction "add", %[[T6]], %[[T3]] : vector<3xf32> into f32 // CHECK: return %[[T7]] : f32 func @full_contract1(%arg0: vector<2x3xf32>, @@ -228,7 +228,7 @@ // CHECK: %[[T8:.*]] = vector.extract %[[T7]][0] : vector<2xf32> // CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T6]] [2] : f32 into vector<3xf32> // CHECK: %[[T10:.*]] = vector.fma %[[T0]], %[[T9]], %[[Z]] : vector<3xf32> -// CHECK: %[[T11:.*]] = vector.reductionv2 "add", %[[T10]], %[[C]] : vector<3xf32>, f32 into f32 +// CHECK: %[[T11:.*]] = vector.reduction "add", %[[T10]], %[[C]] : vector<3xf32> into f32 // CHECK: %[[T12:.*]] = vector.extract %[[A]][1] : vector<2x3xf32> // CHECK: %[[T13:.*]] = vector.extract %[[B]][0] : vector<3x2xf32> // CHECK: %[[T14:.*]] = vector.extract %[[T13]][1] : vector<2xf32> @@ -240,7 +240,7 @@ // CHECK: %[[T20:.*]] = vector.extract %[[T19]][1] : vector<2xf32> // CHECK: %[[T21:.*]] = vector.insert %[[T20]], %[[T18]] [2] : f32 into vector<3xf32> // CHECK: %[[T22:.*]] = vector.fma %[[T12]], %[[T21]], %[[Z]] : vector<3xf32> -// CHECK: %[[T23:.*]] = 
vector.reductionv2 "add", %[[T22]], %[[T11]] : vector<3xf32>, f32 into f32 +// CHECK: %[[T23:.*]] = vector.reduction "add", %[[T22]], %[[T11]] : vector<3xf32> into f32 // CHECK: return %[[T23]] : f32 func @full_contract2(%arg0: vector<2x3xf32>,