diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -2425,4 +2425,60 @@
   let verifier = ?;
 }
 
+//===----------------------------------------------------------------------===//
+// VectorScanOp
+//===----------------------------------------------------------------------===//
+
+def Vector_ScanOp :
+  Vector_Op<"scan", [NoSideEffect,
+    PredOpTrait<"source operand and result have same element type",
+                TCresVTEtIsSameAsOpBase<0, 0>>]>,
+    Arguments<(ins Vector_CombiningKindAttr:$kind,
+                   AnyVector:$source,
+                   AnyType:$identity,
+                   I64ArrayAttr:$reduction_dims,
+                   BoolAttr:$inclusive)>,
+    Results<(outs AnyVector:$dest)> {
+  let summary = "Scan operation";
+  let description = [{
+    Performs an inclusive/exclusive scan on an n-D vector along a single
+    dimension, producing an n-D vector of the same type, using the given
+    operation (add/mul/min/max for int/fp and and/or/xor for int only)
+    and a specified value for the identity element.
+
+    `reduction_dims` must hold exactly one dimension `d` with
+    `0 <= d < rank(source)`. The identity element must have the same type
+    as the elements of `source`; it is only used in the exclusive scan.
+
+    Example:
+
+    ```mlir
+    %1 = vector.scan <add>, %0, %identity [1] {inclusive = false} :
+      (vector<4x8x16x32xf32>, f32) to vector<4x8x16x32xf32>
+    %3 = vector.scan <add>, %2, %identity [0] {inclusive = true} :
+      (vector<4x16xf32>, f32) to vector<4x16xf32>
+    ```
+  }];
+  let builders = [
+    OpBuilder<(ins "Value":$source, "Value":$identity,
+                   "CombiningKind":$kind,
+                   "ArrayAttr":$reduction_dims,
+                   CArg<"bool", "true">:$inclusive)>
+  ];
+  let extraClassDeclaration = [{
+    static StringRef getKindAttrName() { return "kind"; }
+    static StringRef getReductionDimAttrName() { return "reduction_dims"; }
+    static StringRef getInclusiveAttrName() { return "inclusive"; }
+
+    VectorType getSourceType() {
+      return source().getType().cast<VectorType>();
+    }
+    VectorType getDestType() {
+      return dest().getType().cast<VectorType>();
+    }
+  }];
+  let assemblyFormat =
+    "$kind `,` $source `,` $identity $reduction_dims attr-dict `:` `(` type($source) `,` type($identity) `)` `to` type($dest)";
+}
+
 #endif // VECTOR_OPS
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -4181,6 +4181,59 @@
   results.add(context);
 }
 
+//===----------------------------------------------------------------------===//
+// ScanOp
+//===----------------------------------------------------------------------===//
+
+void ScanOp::build(OpBuilder &builder, OperationState &result, Value source,
+                   Value identity, CombiningKind kind, ArrayAttr reductionDims,
+                   bool inclusive) {
+  result.addOperands({source, identity});
+  // A scan preserves the shape and element type of its source, so the
+  // result type is the source type.
+  result.addTypes(source.getType());
+  result.addAttribute(getKindAttrName(),
+                      CombiningKindAttr::get(kind, builder.getContext()));
+  result.addAttribute(getReductionDimAttrName(), reductionDims);
+  result.addAttribute(getInclusiveAttrName(), builder.getBoolAttr(inclusive));
+}
+
+static LogicalResult verify(ScanOp op) {
+  VectorType srcType = op.getSourceType();
+  VectorType dstType = op.getDestType();
+  if (srcType != dstType)
+    return op.emitOpError("src and dst types must match");
+
+  // The scan runs along exactly one dimension, which must index into the
+  // source vector.
+  ArrayAttr reductionDims = op.reduction_dims();
+  if (reductionDims.size() != 1)
+    return op.emitOpError("expected exactly one reduction dimension, got ")
+           << reductionDims.size();
+  int64_t rank = srcType.getRank();
+  int64_t dim = reductionDims[0].cast<IntegerAttr>().getInt();
+  if (dim < 0 || dim >= rank)
+    return op.emitOpError("reduction dimension ")
+           << dim << " has to be in [0, " << rank << ")";
+
+  // The identity is combined with source elements, so it must be a scalar
+  // of the source element type.
+  Type elementType = srcType.getElementType();
+  Type identityType = op.identity().getType();
+  if (!(identityType.isa<FloatType>() || identityType.isa<IntegerType>()))
+    return op.emitOpError(
+        "expected identity element type to be float or integer");
+  if (identityType != elementType)
+    return op.emitOpError("expected identity type ")
+           << identityType << " to match source element type " << elementType;
+
+  // Enforce the kind/element-type pairing documented on the op.
+  if (!isSupportedCombiningKind(op.kind(), elementType))
+    return op.emitOpError("unsupported combining kind '")
+           << stringifyCombiningKind(op.kind()) << "' for element type "
+           << elementType;
+  return success();
+}
+
 void mlir::vector::populateVectorToVectorCanonicalizationPatterns(
     RewritePatternSet &patterns) {
   patterns
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -717,3 +717,12 @@
   %0 = vector.vscale
   return %0 : index
 }
+
+// CHECK-LABEL: @vector_scan
+func @vector_scan(%0: vector<4x8x16x32xf32>) -> vector<4x8x16x32xf32> {
+  %1 = arith.constant 0.0 : f32
+  // CHECK: vector.scan <add>, %{{.*}}, %{{.*}} [1] {inclusive = true}
+  %2 = vector.scan <add>, %0, %1 [1] {inclusive = true} :
+    (vector<4x8x16x32xf32>, f32) to vector<4x8x16x32xf32>
+  return %2 : vector<4x8x16x32xf32>
+}