diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1296,14 +1296,20 @@
   //   transfer_write_in_bounds(transfer_read_masked(pad_source, pad_value))
   OpBuilder::InsertionGuard g(rewriter);
   rewriter.setInsertionPoint(padOp);
-  auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
-  auto emptyOp =
-      rewriter.create<tensor::EmptyOp>(loc, padOp.getResultType(),
-                                       /*dynamicSizes=*/ValueRange{});
+
+  ReifiedRankedShapedTypeDims reifiedReturnShapes;
+  LogicalResult status =
+      cast<ReifyRankedShapedTypeOpInterface>(padOp.getOperation())
+          .reifyResultShapes(rewriter, reifiedReturnShapes);
+  (void)status; // prevent unused variable warning on non-assert builds
+  assert(succeeded(status) && "failed to reify result shapes");
+  auto emptyOp = rewriter.create<tensor::EmptyOp>(loc, reifiedReturnShapes[0],
+                                                  padValue.getType());
   SmallVector<OpFoldResult> mixedSourceDims =
       getMixedDimensions(rewriter, loc, padOp.getSource());
   Value mask =
       rewriter.create<vector::CreateMaskOp>(loc, maskType, mixedSourceDims);
+  auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
   auto transferReadOp = rewriter.create<vector::TransferReadOp>(
       loc,
       /*vectorType=*/vectorType,
@@ -1313,13 +1319,25 @@
       /*inBounds=*/SmallVector<bool>(rank, true));
   auto maskedOp = cast<vector::MaskOp>(
       mlir::vector::maskOperation(rewriter, transferReadOp, mask));
-  auto transferWriteOp = rewriter.create<vector::TransferWriteOp>(
+  Operation *write = rewriter.create<vector::TransferWriteOp>(
       loc,
       /*vector=*/maskedOp->getResult(0),
       /*source=*/emptyOp,
       /*indices=*/SmallVector<Value>(rank, zero),
       /*inBounds=*/SmallVector<bool>(rank, true));
-  newResults.push_back(transferWriteOp.getResult());
+  // The write covers the padded *result*, so it must be masked with the
+  // result sizes rather than with the source-sized read mask: reusing `mask`
+  // would skip the padding region and leave it uninitialized in `emptyOp`.
+  // No mask is needed when the vector shape matches the static result shape.
+  bool needMaskForWrite = llvm::any_of(
+      llvm::zip_equal(vectorType.getShape(), padOp.getResultType().getShape()),
+      [](auto it) { return std::get<0>(it) != std::get<1>(it); });
+  if (needMaskForWrite) {
+    Value maskForWrite = rewriter.create<vector::CreateMaskOp>(
+        loc, maskType, reifiedReturnShapes[0]);
+    write = mlir::vector::maskOperation(rewriter, write, maskForWrite);
+  }
+  newResults.push_back(write->getResult(0));
   return success();
 }
 
@@ -1354,6 +1372,35 @@
   return success();
 }
 
+/// Checks that `inputVectorSizes` can be used to vectorize, with masking, an
+/// op whose static shape (or iteration space) is `shape`: the ranks must
+/// match, every vector size must be static, and every vector size must be at
+/// least as large as the corresponding static dimension of `shape`.
+static LogicalResult
+isValidMaskedInputVector(ArrayRef<int64_t> shape,
+                         ArrayRef<int64_t> inputVectorSizes) {
+  if (inputVectorSizes.size() != shape.size()) {
+    LDBG("Input vector sizes don't match the number of loops");
+    return failure();
+  }
+  if (ShapedType::isDynamicShape(inputVectorSizes)) {
+    LDBG("Input vector sizes can't have dynamic dimensions");
+    return failure();
+  }
+  if (!llvm::all_of(llvm::zip(shape, inputVectorSizes),
+                    [](std::tuple<int64_t, int64_t> sizePair) {
+                      int64_t staticSize = std::get<0>(sizePair);
+                      int64_t inputSize = std::get<1>(sizePair);
+                      return ShapedType::isDynamic(staticSize) ||
+                             staticSize <= inputSize;
+                    })) {
+    LDBG("Input vector sizes must be greater than or equal to iteration space "
+         "static sizes");
+    return failure();
+  }
+  return success();
+}
+
 static LogicalResult
 vectorizeLinalgOpPrecondition(LinalgOp linalgOp,
                               ArrayRef<int64_t> inputVectorSizes,
@@ -1363,23 +1410,10 @@
                    [](int64_t dim) { return dim == 0; }))
     return failure();
   // Check API contract for input vector sizes.
-  if (!inputVectorSizes.empty()) {
-    assert(inputVectorSizes.size() == linalgOp.getNumLoops() &&
-           "Input vector sizes don't match the number of loops");
-    assert(!ShapedType::isDynamicShape(inputVectorSizes) &&
-           "Input vector sizes can't have dynamic dimensions");
-    assert(
-        llvm::all_of(
-            llvm::zip(linalgOp.getStaticLoopRanges(), inputVectorSizes),
-            [](std::tuple<int64_t, int64_t> sizePair) {
-              int64_t staticSize = std::get<0>(sizePair);
-              int64_t inputSize = std::get<1>(sizePair);
-              return ShapedType::isDynamic(staticSize) ||
-                     staticSize <= inputSize;
-            }) &&
-        "Input vector sizes must be greater than or equal to iteration space "
-        "static sizes");
-  }
+  if (!inputVectorSizes.empty() &&
+      failed(isValidMaskedInputVector(linalgOp.getStaticLoopRanges(),
+                                      inputVectorSizes)))
+    return failure();
 
   if (linalgOp.hasDynamicShape() &&
       failed(vectorizeDynamicLinalgOpPrecondition(linalgOp))) {
@@ -1445,11 +1479,8 @@
   }
 
   ArrayRef<int64_t> resultTensorShape = padOp.getResultType().getShape();
-  if (!(resultTensorShape == inputVectorSizes)) {
-    LDBG("result tensor shape must match input vector sizes: " << padOp
-                                                               << "\n");
+  if (failed(isValidMaskedInputVector(resultTensorShape, inputVectorSizes)))
     return failure();
-  }
 
   if (llvm::any_of(padOp.getLow(), [](Value v) {
         std::optional<int64_t> res = getConstantIntValue(v);
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -2835,13 +2835,13 @@
   %0 : tensor<?x?xf32>, %h0 : index, %h1 : index)
     -> tensor<2x4xf32>
 {
-  // CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
-  // CHECK-DAG: %[[c0_2:.*]] = arith.constant 0 : index
   // CHECK-DAG: %[[c42:.*]] = arith.constant 4.243000e+01 : f32
+  // CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
   // CHECK-DAG: %[[empty:.*]] = tensor.empty() : tensor<2x4xf32>
   // CHECK: %[[d0:.*]] = tensor.dim {{.*}} : tensor<?x?xf32>
   // CHECK: %[[d1:.*]] = tensor.dim {{.*}} : tensor<?x?xf32>
   // CHECK: %[[mask:.*]] = vector.create_mask %[[d0]], %[[d1]] : vector<2x4xi1>
+  // CHECK-DAG: %[[c0_2:.*]] = arith.constant 0 : index
   // CHECK: %[[masked_read:.*]] = vector.mask %[[mask]] {
   // CHECK-SAME:   vector.transfer_read %{{.*}}[%[[c0_2]], %[[c0_2]]], %[[c42]]
   // CHECK-SAME:   {in_bounds = [true, true]} : tensor<?x?xf32>, vector<2x4xf32>
@@ -2866,6 +2866,47 @@
 
 // -----
 
+// CHECK-LABEL: func @test_masked_vectorize_dynamic_pad
+func.func @test_masked_vectorize_dynamic_pad(
+  %0 : tensor<?x?xf32>, %h0 : index, %h1 : index)
+    -> tensor<?x?xf32>
+{
+  // The read is masked with the source sizes; the write must be masked with
+  // the (reified) result sizes so that the padding region is written too.
+  // CHECK-DAG: %[[c42:.*]] = arith.constant 4.243000e+01 : f32
+  // CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
+  // CHECK: %[[empty:.*]] = tensor.empty(%[[res_d0:.+]], %[[res_d1:.+]]) : tensor<?x?xf32>
+  // CHECK: %[[d0:.*]] = tensor.dim {{.*}} : tensor<?x?xf32>
+  // CHECK: %[[d1:.*]] = tensor.dim {{.*}} : tensor<?x?xf32>
+  // CHECK: %[[mask:.*]] = vector.create_mask %[[d0]], %[[d1]] : vector<2x4xi1>
+  // CHECK-DAG: %[[c0_2:.*]] = arith.constant 0 : index
+  // CHECK: %[[masked_read:.*]] = vector.mask %[[mask]] {
+  // CHECK-SAME:   vector.transfer_read %{{.*}}[%[[c0_2]], %[[c0_2]]], %[[c42]]
+  // CHECK-SAME:   {in_bounds = [true, true]} : tensor<?x?xf32>, vector<2x4xf32>
+  // CHECK-SAME: } : vector<2x4xi1> -> vector<2x4xf32>
+  // CHECK: %[[write_mask:.*]] = vector.create_mask %[[res_d0]], %[[res_d1]] : vector<2x4xi1>
+  // CHECK: %[[masked_write:.*]] = vector.mask %[[write_mask]] {
+  // CHECK-SAME:   vector.transfer_write %[[masked_read]], %[[empty]][%[[c0_2]], %[[c0_2]]]
+  // CHECK-SAME:   {in_bounds = [true, true]} : vector<2x4xf32>, tensor<?x?xf32>
+  // CHECK: return %[[masked_write]] : tensor<?x?xf32>
+  %cst = arith.constant 42.43 : f32
+  %c0 = arith.constant 0 : index
+  %1 = tensor.pad %0 low[0, %c0] high[%h0, %h1]  {
+    ^bb0(%hh1: index, %hh2: index):
+      tensor.yield %cst : f32
+    } : tensor<?x?xf32> to tensor<?x?xf32>
+  return %1: tensor<?x?xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["tensor.pad"]} in %arg1
+    : (!pdl.operation) -> !pdl.operation
+  transform.structured.masked_vectorize %0 vector_sizes [2, 4]
+}
+
+// -----
+
 // CHECK-LABEL: func @test_masked_pad_static_dynamic
 func.func @test_masked_pad_static_dynamic(%arg0: tensor<1x2x2x?xf32>, %low: index, %high: index,
                                           %pad_value: f32) -> tensor<6x?x?x?xf32> {