diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -2429,4 +2429,68 @@
   let verifier = ?;
 }
 
+//===----------------------------------------------------------------------===//
+// VectorScanOp
+//===----------------------------------------------------------------------===//
+
+def Vector_ScanOp :
+  Vector_Op<"scan", [NoSideEffect,
+    PredOpTrait<"source operand and result have same element type",
+                 TCresVTEtIsSameAsOpBase<0, 0>>]>,
+    Arguments<(ins Vector_CombiningKindAttr:$kind,
+                   AnyVector:$source,
+                   AnyVectorOfAnyRank:$initial_value,
+                   I64Attr:$reduction_dim,
+                   BoolAttr:$inclusive)>,
+    Results<(outs AnyVector:$dest,
+                  AnyVectorOfAnyRank:$accumulated_value)> {
+  let summary = "Scan operation";
+  let description = [{
+    Performs an inclusive/exclusive scan on an n-D vector along a single
+    dimension returning an n-D result vector using the given
+    operation (add/mul/min/max for int/fp and and/or/xor for
+    int only) and a specified value for the initial value. The operator
+    returns the result of scan as well as the result of the last
+    reduction in the scan.
+
+    `reduction_dim` has to be smaller than the rank of the source, and the
+    initial value has the shape of the source with that dimension removed.
+
+    Example:
+
+    ```mlir
+    %1:2 = vector.scan <add>, %0, %acc {inclusive = false, reduction_dim = 1 : i64} :
+      (vector<4x8x16x32xf32>, vector<4x16x32xf32>) to (vector<4x8x16x32xf32>, vector<4x16x32xf32>)
+    %3:2 = vector.scan <add>, %2, %acc2 {inclusive = true, reduction_dim = 0 : i64} :
+      (vector<4x16xf32>, vector<16xf32>) to (vector<4x16xf32>, vector<16xf32>)
+    ```
+  }];
+  let builders = [
+    OpBuilder<(ins "Value":$source, "Value":$initial_value,
+                   "CombiningKind":$kind,
+                   CArg<"int64_t", "0">:$reduction_dim,
+                   CArg<"bool", "true">:$inclusive)>
+  ];
+  let extraClassDeclaration = [{
+    static StringRef getKindAttrName() { return "kind"; }
+    static StringRef getReductionDimAttrName() { return "reduction_dim"; }
+    VectorType getSourceType() {
+      return source().getType().cast<VectorType>();
+    }
+    VectorType getDestType() {
+      return dest().getType().cast<VectorType>();
+    }
+    VectorType getAccumulatorType() {
+      return accumulated_value().getType().cast<VectorType>();
+    }
+    VectorType getInitialValueType() {
+      return initial_value().getType().cast<VectorType>();
+    }
+  }];
+  let assemblyFormat =
+    "$kind `,` $source `,` $initial_value attr-dict `:` "
+    "`(` type($source) `,` type($initial_value) `)` `to` "
+    "`(` type($dest) `,` type($accumulated_value) `)` ";
+}
+
 #endif // VECTOR_OPS
diff --git a/mlir/include/mlir/Dialect/Vector/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/VectorRewritePatterns.h
--- a/mlir/include/mlir/Dialect/Vector/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorRewritePatterns.h
@@ -171,6 +171,10 @@
 /// transpose/broadcast ops into the contract.
 void populateVectorReductionToContractPatterns(RewritePatternSet &patterns);
 
+/// Collect a set of patterns that lower vector.scan into arith ops and
+/// vector.insert_strided_slice/extract_strided_slice ops.
+void populateVectorScanLoweringPatterns(RewritePatternSet &patterns);
+
 //===----------------------------------------------------------------------===//
 // Vector.transfer patterns.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -4263,6 +4263,59 @@
   results.add(context);
 }
 
+//===----------------------------------------------------------------------===//
+// ScanOp
+//===----------------------------------------------------------------------===//
+
+/// Verifies that source/dest and initial/accumulated types agree, that the
+/// reduction dimension lies in [0, rank(source)), and that the initial value
+/// has the shape of the source with the reduction dimension removed.
+static LogicalResult verify(ScanOp op) {
+  VectorType srcType = op.getSourceType();
+  VectorType dstType = op.getDestType();
+  if (srcType != dstType)
+    return op.emitError("src and dst types must match");
+
+  VectorType initialType = op.getInitialValueType();
+  VectorType accType = op.getAccumulatorType();
+  if (initialType != accType)
+    return op.emitError("initial value and accumulator types must match");
+
+  // Check 0 <= reduction dimension < rank. `reduction_dim` is a plain I64Attr,
+  // so negative values have to be rejected explicitly.
+  int64_t srcRank = srcType.getRank();
+  int64_t reductionDim = op.reduction_dim();
+  if (reductionDim < 0)
+    return op.emitOpError("reduction dimension ")
+           << reductionDim << " has to be >= 0";
+  if (reductionDim >= srcRank)
+    return op.emitOpError("reduction dimension ")
+           << reductionDim << " has to be < " << srcRank;
+
+  // Check that rank(initial_value) = rank(src) - 1.
+  int64_t initialValueRank = initialType.getRank();
+  if (initialValueRank != srcRank - 1)
+    return op.emitOpError("initial value rank ")
+           << initialValueRank << " has to be equal to " << srcRank - 1;
+
+  // Check shapes of initial value and src.
+  ArrayRef<int64_t> srcShape = srcType.getShape();
+  ArrayRef<int64_t> initialValueShapes = initialType.getShape();
+  SmallVector<int64_t> expectedShape;
+  for (int i = 0; i < srcRank; i++) {
+    if (i != reductionDim)
+      expectedShape.push_back(srcShape[i]);
+  }
+  if (llvm::any_of(llvm::zip(initialValueShapes, expectedShape),
+                   [](std::tuple<int64_t, int64_t> s) {
+                     return std::get<0>(s) != std::get<1>(s);
+                   })) {
+    return op.emitOpError("incompatible input/initial value shapes");
+  }
+
+  return success();
+}
+
 void mlir::vector::populateVectorToVectorCanonicalizationPatterns(
     RewritePatternSet &patterns) {
   patterns
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -2348,6 +2348,237 @@
   }
 };
 
+namespace {
+
+/// Convert vector.scan op into arith ops and
+/// vector.insert_strided_slice/extract_strided_slice.
+///
+/// The source is sliced along the reduction dimension and every slice is
+/// combined with the running value. The pattern fails to match, before any IR
+/// is created, when the combining kind is not defined for the element type.
+///
+/// Ex:
+/// ```
+///   %0:2 = vector.scan <mul>, %arg0, %arg1
+///     {inclusive = true, reduction_dim = 1} :
+///     (vector<2x3xi32>, vector<2xi32>) to (vector<2x3xi32>, vector<2xi32>)
+/// ```
+/// Gets converted to:
+/// ```
+///   %cst = arith.constant dense<0> : vector<2x3xi32>
+///   %0 = vector.extract_strided_slice %arg0
+///     {offsets = [0, 0], sizes = [2, 1], strides = [1, 1]}
+///     : vector<2x3xi32> to vector<2x1xi32>
+///   %1 = vector.insert_strided_slice %0, %cst
+///     {offsets = [0, 0], strides = [1, 1]}
+///     : vector<2x1xi32> into vector<2x3xi32>
+///   %2 = vector.extract_strided_slice %arg0
+///     {offsets = [0, 1], sizes = [2, 1], strides = [1, 1]}
+///     : vector<2x3xi32> to vector<2x1xi32>
+///   %3 = arith.muli %0, %2 : vector<2x1xi32>
+///   %4 = vector.insert_strided_slice %3, %1
+///     {offsets = [0, 1], strides = [1, 1]}
+///     : vector<2x1xi32> into vector<2x3xi32>
+///   %5 = vector.extract_strided_slice %arg0
+///     {offsets = [0, 2], sizes = [2, 1], strides = [1, 1]}
+///     : vector<2x3xi32> to vector<2x1xi32>
+///   %6 = arith.muli %3, %5 : vector<2x1xi32>
+///   %7 = vector.insert_strided_slice %6, %4
+///     {offsets = [0, 2], strides = [1, 1]}
+///     : vector<2x1xi32> into vector<2x3xi32>
+///   %8 = vector.shape_cast %6 : vector<2x1xi32> to vector<2xi32>
+///   return %7, %8 : vector<2x3xi32>, vector<2xi32>
+/// ```
+struct ScanToArithOps : public OpRewritePattern<vector::ScanOp> {
+  using OpRewritePattern<vector::ScanOp>::OpRewritePattern;
+
+  /// Returns true if `kind` has a lowering for the element category selected
+  /// by `isInt` (integer/index when true, floating point otherwise).
+  static bool isValidKind(bool isInt, vector::CombiningKind kind) {
+    using vector::CombiningKind;
+    switch (kind) {
+    case CombiningKind::ADD:
+    case CombiningKind::MUL:
+      return true;
+    case CombiningKind::MINUI:
+    case CombiningKind::MINSI:
+    case CombiningKind::MAXUI:
+    case CombiningKind::MAXSI:
+    case CombiningKind::AND:
+    case CombiningKind::OR:
+    case CombiningKind::XOR:
+      return isInt;
+    case CombiningKind::MINF:
+    case CombiningKind::MAXF:
+      return !isInt;
+    default:
+      return false;
+    }
+  }
+
+  /// Emits the integer arith op implementing `kind` on `x` and `y`; returns
+  /// an empty Optional for kinds that have no integer form.
+  static Optional<Value> genOperatorI(Location loc, Value x, Value y,
+                                      vector::CombiningKind kind,
+                                      PatternRewriter &rewriter) {
+    using vector::CombiningKind;
+
+    Value combinedResult;
+    switch (kind) {
+    case CombiningKind::ADD:
+      combinedResult = rewriter.create<arith::AddIOp>(loc, x, y);
+      break;
+    case CombiningKind::MUL:
+      combinedResult = rewriter.create<arith::MulIOp>(loc, x, y);
+      break;
+    case CombiningKind::MINUI:
+      combinedResult = rewriter.create<arith::MinUIOp>(loc, x, y);
+      break;
+    case CombiningKind::MINSI:
+      combinedResult = rewriter.create<arith::MinSIOp>(loc, x, y);
+      break;
+    case CombiningKind::MAXUI:
+      combinedResult = rewriter.create<arith::MaxUIOp>(loc, x, y);
+      break;
+    case CombiningKind::MAXSI:
+      combinedResult = rewriter.create<arith::MaxSIOp>(loc, x, y);
+      break;
+    case CombiningKind::AND:
+      combinedResult = rewriter.create<arith::AndIOp>(loc, x, y);
+      break;
+    case CombiningKind::OR:
+      combinedResult = rewriter.create<arith::OrIOp>(loc, x, y);
+      break;
+    case CombiningKind::XOR:
+      combinedResult = rewriter.create<arith::XOrIOp>(loc, x, y);
+      break;
+    default:
+      return Optional<Value>();
+    }
+    return Optional<Value>(combinedResult);
+  }
+
+  /// Emits the floating-point arith op implementing `kind` on `x` and `y`;
+  /// returns an empty Optional for kinds that have no floating-point form.
+  static Optional<Value> genOperatorF(Location loc, Value x, Value y,
+                                      vector::CombiningKind kind,
+                                      PatternRewriter &rewriter) {
+    using vector::CombiningKind;
+    Value combinedResult;
+    switch (kind) {
+    case CombiningKind::ADD:
+      combinedResult = rewriter.create<arith::AddFOp>(loc, x, y);
+      break;
+    case CombiningKind::MUL:
+      combinedResult = rewriter.create<arith::MulFOp>(loc, x, y);
+      break;
+    case CombiningKind::MINF:
+      combinedResult = rewriter.create<arith::MinFOp>(loc, x, y);
+      break;
+    case CombiningKind::MAXF:
+      combinedResult = rewriter.create<arith::MaxFOp>(loc, x, y);
+      break;
+    default:
+      return Optional<Value>();
+    }
+    return Optional<Value>(combinedResult);
+  }
+
+  LogicalResult matchAndRewrite(vector::ScanOp scanOp,
+                                PatternRewriter &rewriter) const override {
+    auto loc = scanOp.getLoc();
+    VectorType destType = scanOp.getDestType();
+    ArrayRef<int64_t> destShape = destType.getShape();
+    auto elType = destType.getElementType();
+    // Index vectors are lowered with the integer arith ops as well.
+    bool isInt = elType.isIntOrIndex();
+    // Bail out before creating any op: a pattern must not leave new IR behind
+    // when it reports failure.
+    if (!isValidKind(isInt, scanOp.kind()))
+      return failure();
+
+    VectorType resType = VectorType::get(destShape, elType);
+    Value result = rewriter.create<arith::ConstantOp>(
+        loc, resType, rewriter.getZeroAttr(resType));
+    int64_t reductionDim = scanOp.reduction_dim();
+    bool inclusive = scanOp.inclusive();
+    int64_t destRank = destType.getRank();
+    VectorType initialValueType = scanOp.getInitialValueType();
+    int64_t initialValueRank = initialValueType.getRank();
+
+    // Shape of one slice: the destination shape with the scanned dim set to 1.
+    SmallVector<int64_t> reductionShape(destShape.begin(), destShape.end());
+    reductionShape[reductionDim] = 1;
+    VectorType reductionType = VectorType::get(reductionShape, elType);
+    SmallVector<int64_t> offsets, sizes, strides;
+    for (int i = 0; i < destRank; i++) {
+      offsets.push_back(0);
+      strides.push_back(1);
+      if (i == reductionDim) {
+        sizes.push_back(1);
+      } else {
+        sizes.push_back(destShape[i]);
+      }
+    }
+    ArrayAttr scanSizes = rewriter.getI64ArrayAttr(sizes);
+    ArrayAttr scanStrides = rewriter.getI64ArrayAttr(strides);
+
+    Value lastOutput, lastInput;
+    for (int i = 0; i < destShape[reductionDim]; i++) {
+      offsets[reductionDim] = i;
+      ArrayAttr scanOffsets = rewriter.getI64ArrayAttr(offsets);
+      Value input = rewriter.create<vector::ExtractStridedSliceOp>(
+          loc, reductionType, scanOp.source(), scanOffsets, scanSizes,
+          scanStrides);
+      Value output;
+      if (i == 0) {
+        if (inclusive) {
+          output = input;
+        } else {
+          if (initialValueRank == 0) {
+            // ShapeCastOp cannot handle 0-D vectors
+            output = rewriter.create<vector::BroadcastOp>(
+                loc, input.getType(), scanOp.initial_value());
+          } else {
+            output = rewriter.create<vector::ShapeCastOp>(
+                loc, input.getType(), scanOp.initial_value());
+          }
+        }
+      } else {
+        // An exclusive scan combines with the previous slice, so element i
+        // does not include source element i.
+        Value y = inclusive ? input : lastInput;
+        Optional<Value> combined =
+            isInt ? genOperatorI(loc, lastOutput, y, scanOp.kind(), rewriter)
+                  : genOperatorF(loc, lastOutput, y, scanOp.kind(), rewriter);
+        assert(combined.hasValue() && "combining kind was validated above");
+        output = combined.getValue();
+      }
+      result = rewriter.create<vector::InsertStridedSliceOp>(
+          loc, output, result, offsets, strides);
+      lastOutput = output;
+      lastInput = input;
+    }
+
+    // The accumulated value is the last computed slice, reshaped to the type
+    // of the initial value.
+    Value reduction;
+    if (initialValueRank == 0) {
+      Value v = rewriter.create<vector::ExtractOp>(loc, lastOutput, 0);
+      reduction =
+          rewriter.create<vector::BroadcastOp>(loc, initialValueType, v);
+    } else {
+      reduction = rewriter.create<vector::ShapeCastOp>(loc, initialValueType,
+                                                       lastOutput);
+    }
+
+    rewriter.replaceOp(scanOp, {result, reduction});
+    return success();
+  }
+};
+
+} // namespace
+
 void mlir::vector::populateVectorMaskMaterializationPatterns(
     RewritePatternSet &patterns, bool indexOptimizations) {
   patterns.add(
       patterns.getContext());
 }
+
+void mlir::vector::populateVectorScanLoweringPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<ScanToArithOps>(patterns.getContext());
+}
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1480,3 +1480,56 @@
   %0 = vector.insert_map %v, %v1[%id] : vector<2x1xf32> into vector<4x32xf32>
 }
 
+// -----
+
+func @scan_src_dst_type_mismatch(%arg0: vector<2x3xi32>, %arg1: vector<3xi32>) -> vector<3x4xi32> {
+  // expected-error@+1 {{src and dst types must match}}
+  %0:2 = vector.scan <add>, %arg0, %arg1 {inclusive = true, reduction_dim = 0} :
+    (vector<2x3xi32>, vector<3xi32>) to (vector<3x4xi32>, vector<3xi32>)
+  return %0#0 : vector<3x4xi32>
+}
+
+// -----
+
+func @scan_ival_acc_type_mismatch(%arg0: vector<2x3xi32>, %arg1: vector<3xi32>) -> vector<4xi32> {
+  // expected-error@+1 {{initial value and accumulator types must match}}
+  %0:2 = vector.scan <add>, %arg0, %arg1 {inclusive = true, reduction_dim = 0} :
+    (vector<2x3xi32>, vector<3xi32>) to (vector<2x3xi32>, vector<4xi32>)
+  return %0#1 : vector<4xi32>
+}
+
+// -----
+
+func @scan_negative_reduction_dim(%arg0: vector<2x3xi32>, %arg1: vector<3xi32>) -> vector<3xi32> {
+  // expected-error@+1 {{'vector.scan' op reduction dimension -1 has to be >= 0}}
+  %0:2 = vector.scan <add>, %arg0, %arg1 {inclusive = true, reduction_dim = -1} :
+    (vector<2x3xi32>, vector<3xi32>) to (vector<2x3xi32>, vector<3xi32>)
+  return %0#1 : vector<3xi32>
+}
+
+// -----
+
+func
@scan_reduction_dim_constraint(%arg0: vector<2x3xi32>, %arg1: vector<3xi32>) -> vector<3xi32> {
+  // expected-error@+1 {{'vector.scan' op reduction dimension 5 has to be < 2}}
+  %0:2 = vector.scan <add>, %arg0, %arg1 {inclusive = true, reduction_dim = 5} :
+    (vector<2x3xi32>, vector<3xi32>) to (vector<2x3xi32>, vector<3xi32>)
+  return %0#1 : vector<3xi32>
+}
+
+// -----
+
+func @scan_ival_rank_constraint(%arg0: vector<2x3xi32>, %arg1: vector<1x3xi32>) -> vector<1x3xi32> {
+  // expected-error@+1 {{initial value rank 2 has to be equal to 1}}
+  %0:2 = vector.scan <add>, %arg0, %arg1 {inclusive = true, reduction_dim = 0} :
+    (vector<2x3xi32>, vector<1x3xi32>) to (vector<2x3xi32>, vector<1x3xi32>)
+  return %0#1 : vector<1x3xi32>
+}
+
+// -----
+
+func @scan_incompatible_shapes(%arg0: vector<2x3xi32>, %arg1: vector<5xi32>) -> vector<2x3xi32> {
+  // expected-error@+1 {{incompatible input/initial value shapes}}
+  %0:2 = vector.scan <add>, %arg0, %arg1 {inclusive = true, reduction_dim = 0} :
+    (vector<2x3xi32>, vector<5xi32>) to (vector<2x3xi32>, vector<5xi32>)
+  return %0#0 : vector<2x3xi32>
+}
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -717,3 +717,11 @@
   %0 = vector.vscale
   return %0 : index
 }
+
+// CHECK-LABEL: @vector_scan
+func @vector_scan(%0: vector<4x8x16x32xf32>) -> vector<4x8x16x32xf32> {
+  %1 = arith.constant dense<0.0> : vector<4x16x32xf32>
+  %2:2 = vector.scan <add>, %0, %1 {reduction_dim = 1 : i64, inclusive = true} :
+    (vector<4x8x16x32xf32>, vector<4x16x32xf32>) to (vector<4x8x16x32xf32>, vector<4x16x32xf32>)
+  return %2#0 : vector<4x8x16x32xf32>
+}
diff --git a/mlir/test/Dialect/Vector/vector-scan-transforms.mlir b/mlir/test/Dialect/Vector/vector-scan-transforms.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-scan-transforms.mlir
@@ -0,0 +1,91 @@
+// RUN: mlir-opt %s --test-vector-scan-lowering | FileCheck %s
+
+// CHECK-LABEL: func @scan1d_inc
+// CHECK-SAME: %[[ARG0:.*]]: vector<2xi32>,
+// CHECK-SAME: %[[ARG1:.*]]: vector<i32>
+// CHECK:      %[[A:.*]] = arith.constant dense<0> : vector<2xi32>
+// CHECK:      %[[B:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [0], sizes = [1], strides = [1]} : vector<2xi32> to vector<1xi32>
+// CHECK:      %[[C:.*]] = vector.insert_strided_slice %[[B]], %[[A]] {offsets = [0], strides = [1]} : vector<1xi32> into vector<2xi32>
+// CHECK:      %[[D:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [1], sizes = [1], strides = [1]} : vector<2xi32> to vector<1xi32>
+// CHECK:      %[[E:.*]] = arith.addi %[[B]], %[[D]] : vector<1xi32>
+// CHECK:      %[[F:.*]] = vector.insert_strided_slice %[[E]], %[[C]] {offsets = [1], strides = [1]} : vector<1xi32> into vector<2xi32>
+// CHECK:      %[[G:.*]] = vector.extract %[[E]][0] : vector<1xi32>
+// CHECK:      %[[H:.*]] = vector.broadcast %[[G]] : i32 to vector<i32>
+// CHECK:      return %[[F]], %[[H]] : vector<2xi32>, vector<i32>
+func @scan1d_inc(%arg0 : vector<2xi32>, %arg1 : vector<i32>) -> (vector<2xi32>, vector<i32>) {
+  %0:2 = vector.scan <add>, %arg0, %arg1 {inclusive = true, reduction_dim = 0} :
+    (vector<2xi32>, vector<i32>) to (vector<2xi32>, vector<i32>)
+  return %0#0, %0#1 : vector<2xi32>, vector<i32>
+}
+
+// CHECK-LABEL: func @scan1d_exc
+// CHECK-SAME: %[[ARG0:.*]]: vector<2xi32>,
+// CHECK-SAME: %[[ARG1:.*]]: vector<i32>
+// CHECK:      %[[A:.*]] = arith.constant dense<0> : vector<2xi32>
+// CHECK:      %[[B:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [0], sizes = [1], strides = [1]} : vector<2xi32> to vector<1xi32>
+// CHECK:      %[[C:.*]] = vector.broadcast %[[ARG1]] : vector<i32> to vector<1xi32>
+// CHECK:      %[[D:.*]] = vector.insert_strided_slice %[[C]], %[[A]] {offsets = [0], strides = [1]} : vector<1xi32> into vector<2xi32>
+// CHECK:      %[[E:.*]] = arith.addi %[[C]], %[[B]] : vector<1xi32>
+// CHECK:      %[[F:.*]] = vector.insert_strided_slice %[[E]], %[[D]] {offsets = [1], strides = [1]} : vector<1xi32> into vector<2xi32>
+// CHECK:      %[[G:.*]] = vector.extract %[[E]][0] : vector<1xi32>
+// CHECK:      %[[H:.*]] = vector.broadcast %[[G]] : i32 to vector<i32>
+// CHECK:      return %[[F]], %[[H]] : vector<2xi32>, vector<i32>
+func @scan1d_exc(%arg0 : vector<2xi32>, %arg1 : vector<i32>) -> (vector<2xi32>, vector<i32>) {
+  %0:2 = vector.scan <add>, %arg0, %arg1 {inclusive = false, reduction_dim = 0} :
+    (vector<2xi32>, vector<i32>) to (vector<2xi32>, vector<i32>)
+  return %0#0, %0#1 : vector<2xi32>, vector<i32>
+}
+
+// CHECK-LABEL: func @scan2d_mul_dim0
+// CHECK-SAME: %[[ARG0:.*]]: vector<2x3xi32>,
+// CHECK-SAME: %[[ARG1:.*]]: vector<3xi32>
+// CHECK:      %[[A:.*]] = arith.constant dense<0> : vector<2x3xi32>
+// CHECK:      %[[B:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [0, 0], sizes = [1, 3], strides = [1, 1]} : vector<2x3xi32> to vector<1x3xi32>
+// CHECK:      %[[C:.*]] = vector.insert_strided_slice %[[B]], %[[A]] {offsets = [0, 0], strides = [1, 1]} : vector<1x3xi32> into vector<2x3xi32>
+// CHECK:      %[[D:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [1, 0], sizes = [1, 3], strides = [1, 1]} : vector<2x3xi32> to vector<1x3xi32>
+// CHECK:      %[[E:.*]] = arith.muli %[[B]], %[[D]] : vector<1x3xi32>
+// CHECK:      %[[F:.*]] = vector.insert_strided_slice %[[E]], %[[C]] {offsets = [1, 0], strides = [1, 1]} : vector<1x3xi32> into vector<2x3xi32>
+// CHECK:      %[[G:.*]] = vector.shape_cast %[[E]] : vector<1x3xi32> to vector<3xi32>
+// CHECK:      return %[[F]], %[[G]] : vector<2x3xi32>, vector<3xi32>
+func @scan2d_mul_dim0(%arg0 : vector<2x3xi32>, %arg1 : vector<3xi32>) -> (vector<2x3xi32>, vector<3xi32>) {
+  %0:2 = vector.scan <mul>, %arg0, %arg1 {inclusive = true, reduction_dim = 0} :
+    (vector<2x3xi32>, vector<3xi32>) to (vector<2x3xi32>, vector<3xi32>)
+  return %0#0, %0#1 : vector<2x3xi32>, vector<3xi32>
+}
+
+// CHECK-LABEL: func @scan2d_mul_dim1
+// CHECK-SAME: %[[ARG0:.*]]: vector<2x3xi32>,
+// CHECK-SAME: %[[ARG1:.*]]: vector<2xi32>
+// CHECK:      %[[A:.*]] = arith.constant dense<0> : vector<2x3xi32>
+// CHECK:      %[[B:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [0, 0], sizes = [2, 1], strides = [1, 1]} : vector<2x3xi32> to vector<2x1xi32>
+// CHECK:      %[[C:.*]] = vector.insert_strided_slice %[[B]], %[[A]] {offsets = [0, 0], strides = [1, 1]} : vector<2x1xi32> into vector<2x3xi32>
+// CHECK:      %[[D:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [0, 1], sizes = [2, 1], strides = [1, 1]} : vector<2x3xi32> to vector<2x1xi32>
+// CHECK:      %[[E:.*]] = arith.muli %[[B]], %[[D]] : vector<2x1xi32>
+// CHECK:      %[[F:.*]] = vector.insert_strided_slice %[[E]], %[[C]] {offsets = [0, 1], strides = [1, 1]} : vector<2x1xi32> into vector<2x3xi32>
+// CHECK:      %[[G:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [0, 2], sizes = [2, 1], strides = [1, 1]} : vector<2x3xi32> to vector<2x1xi32>
+// CHECK:      %[[H:.*]] = arith.muli %[[E]], %[[G]] : vector<2x1xi32>
+// CHECK:      %[[I:.*]] = vector.insert_strided_slice %[[H]], %[[F]] {offsets = [0, 2], strides = [1, 1]} : vector<2x1xi32> into vector<2x3xi32>
+// CHECK:      %[[J:.*]] = vector.shape_cast %[[H]] : vector<2x1xi32> to vector<2xi32>
+// CHECK:      return %[[I]], %[[J]] : vector<2x3xi32>, vector<2xi32>
+func @scan2d_mul_dim1(%arg0 : vector<2x3xi32>, %arg1 : vector<2xi32>) -> (vector<2x3xi32>, vector<2xi32>) {
+  %0:2 = vector.scan <mul>, %arg0, %arg1 {inclusive = true, reduction_dim = 1} :
+    (vector<2x3xi32>, vector<2xi32>) to (vector<2x3xi32>, vector<2xi32>)
+  return %0#0, %0#1 : vector<2x3xi32>, vector<2xi32>
+}
+
+// CHECK-LABEL: func @scan3d_mul_dim1
+// CHECK-SAME: %[[ARG0:.*]]: vector<4x2x3xf32>,
+// CHECK-SAME: %[[ARG1:.*]]: vector<4x3xf32>
+// CHECK:      %[[A:.*]] = arith.constant dense<0.000000e+00> : vector<4x2x3xf32>
+// CHECK:      %[[B:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [0, 0, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x2x3xf32> to vector<4x1x3xf32>
+// CHECK:      %[[C:.*]] = vector.shape_cast %[[ARG1]] : vector<4x3xf32> to vector<4x1x3xf32>
+// CHECK:      %[[D:.*]] = vector.insert_strided_slice %[[C]], %[[A]] {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<4x1x3xf32> into vector<4x2x3xf32>
+// CHECK:      %[[E:.*]] = arith.mulf %[[C]], %[[B]] : vector<4x1x3xf32>
+// CHECK:      %[[F:.*]] = vector.insert_strided_slice %[[E]], %[[D]] {offsets = [0, 1, 0], strides = [1, 1, 1]} : vector<4x1x3xf32> into vector<4x2x3xf32>
+// CHECK:      %[[G:.*]] = vector.shape_cast %[[E]] : vector<4x1x3xf32> to vector<4x3xf32>
+// CHECK:      return %[[F]], %[[G]] : vector<4x2x3xf32>, vector<4x3xf32>
+func @scan3d_mul_dim1(%arg0 : vector<4x2x3xf32>, %arg1 : vector<4x3xf32>) -> (vector<4x2x3xf32>, vector<4x3xf32>) {
+  %0:2 = vector.scan <mul>, %arg0, %arg1 {inclusive = false, reduction_dim = 1} :
+    (vector<4x2x3xf32>, vector<4x3xf32>) to (vector<4x2x3xf32>, vector<4x3xf32>)
+  return %0#0, %0#1 : vector<4x2x3xf32>, vector<4x3xf32>
+}
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-scan.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-scan.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-scan.mlir
@@ -0,0 +1,54 @@
+// RUN: mlir-opt %s -test-vector-scan-lowering -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm -reconcile-unrealized-casts | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void  \
+// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+func @entry() {
+  %f1 = arith.constant 1.0: f32
+  %f2 = arith.constant 2.0: f32
+  %f3 = arith.constant 3.0: f32
+  %f4 = arith.constant 4.0: f32
+  %f5 = arith.constant 5.0: f32
+  %f6 = arith.constant 6.0: f32
+
+  // Construct test vector.
+  %0 = vector.broadcast %f1 : f32 to vector<3x2xf32>
+  %1 = vector.insert %f2, %0[0, 1] : f32 into vector<3x2xf32>
+  %2 = vector.insert %f3, %1[1, 0] : f32 into vector<3x2xf32>
+  %3 = vector.insert %f4, %2[1, 1] : f32 into vector<3x2xf32>
+  %4 = vector.insert %f5, %3[2, 0] : f32 into vector<3x2xf32>
+  %x = vector.insert %f6, %4[2, 1] : f32 into vector<3x2xf32>
+  vector.print %x : vector<3x2xf32>
+  // CHECK: ( ( 1, 2 ), ( 3, 4 ), ( 5, 6 ) )
+
+  %y = vector.broadcast %f6 : f32 to vector<2xf32>
+  %z = vector.broadcast %f6 : f32 to vector<3xf32>
+  // Scan
+  %a:2 = vector.scan <add>, %x, %y {inclusive = true, reduction_dim = 0} :
+    (vector<3x2xf32>, vector<2xf32>) to (vector<3x2xf32>, vector<2xf32>)
+  %b:2 = vector.scan <add>, %x, %z {inclusive = true, reduction_dim = 1} :
+    (vector<3x2xf32>, vector<3xf32>) to (vector<3x2xf32>, vector<3xf32>)
+  %c:2 = vector.scan <add>, %x, %y {inclusive = false, reduction_dim = 0} :
+    (vector<3x2xf32>, vector<2xf32>) to (vector<3x2xf32>, vector<2xf32>)
+  %d:2 = vector.scan <add>, %x, %z {inclusive = false, reduction_dim = 1} :
+    (vector<3x2xf32>, vector<3xf32>) to (vector<3x2xf32>, vector<3xf32>)
+
+  // CHECK: ( ( 1, 2 ), ( 4, 6 ), ( 9, 12 ) )
+  // CHECK: ( 9, 12 )
+  // CHECK: ( ( 1, 3 ), ( 3, 7 ), ( 5, 11 ) )
+  // CHECK: ( 3, 7, 11 )
+  // CHECK: ( ( 6, 6 ), ( 7, 8 ), ( 10, 12 ) )
+  // CHECK: ( 10, 12 )
+  // CHECK: ( ( 6, 7 ), ( 6, 9 ), ( 6, 11 ) )
+  // CHECK: ( 7, 9, 11 )
+  vector.print %a#0 : vector<3x2xf32>
+  vector.print %a#1 : vector<2xf32>
+  vector.print %b#0 : vector<3x2xf32>
+  vector.print %b#1 : vector<3xf32>
+  vector.print %c#0 : vector<3x2xf32>
+  vector.print %c#1 : vector<2xf32>
+  vector.print %d#0 : vector<3x2xf32>
+  vector.print %d#1 : vector<3xf32>
+
+  return
+}
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -627,6 +627,20 @@
   }
 };
 
+struct TestVectorScanLowering
+    : public PassWrapper<TestVectorScanLowering, OperationPass<FuncOp>> {
+  StringRef getArgument() const final { return "test-vector-scan-lowering"; }
+  StringRef getDescription() const final {
+    return "Test lowering patterns that lower the scan op in the vector "
+           "dialect";
+  }
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    populateVectorScanLoweringPatterns(patterns);
+    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+  }
+};
+
 } // namespace
 
 namespace mlir {
@@ -661,6 +675,8 @@
   PassRegistration();
 
   PassRegistration();
+
+  PassRegistration<TestVectorScanLowering>();
 }
 } // namespace test
 } // namespace mlir