diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1297,8 +1297,9 @@
   OpBuilder::InsertionGuard g(rewriter);
   rewriter.setInsertionPoint(padOp);
   auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+  auto destType = RankedTensorType::get(inputVectorSizes, padValue.getType());
   auto emptyOp =
-      rewriter.create<tensor::EmptyOp>(loc, padOp.getResultType(),
+      rewriter.create<tensor::EmptyOp>(loc, destType,
                                        /*dynamicSizes=*/ValueRange{});
   SmallVector<OpFoldResult> mixedSourceDims =
       getMixedDimensions(rewriter, loc, padOp.getSource());
@@ -1319,7 +1320,12 @@
       /*source=*/emptyOp,
       /*indices=*/SmallVector<Value>(rank, zero),
       /*inBounds=*/SmallVector<bool>(rank, true));
-  newResults.push_back(transferWriteOp.getResult());
+  Value result = transferWriteOp.getResult();
+  if (destType != padOp.getResultType()) {
+    result =
+        rewriter.create<tensor::CastOp>(loc, padOp.getResultType(), result);
+  }
+  newResults.push_back(result);
   return success();
 }
 
@@ -1445,12 +1451,26 @@
   }
 
   ArrayRef<int64_t> resultTensorShape = padOp.getResultType().getShape();
-  if (!(resultTensorShape == inputVectorSizes)) {
-    LDBG("result tensor shape must match input vector sizes: " << padOp
-                                                               << "\n");
+  if (inputVectorSizes.size() != resultTensorShape.size()) {
+    LDBG("Input vector sizes don't match the rank of pad op : " << padOp
+                                                                << "\n");
     return failure();
   }
 
+  assert(!ShapedType::isDynamicShape(inputVectorSizes) &&
+         "Input vector sizes can't have dynamic dimensions");
+
+  for (auto [dimSize, vecSize] :
+       llvm::zip_equal(resultTensorShape, inputVectorSizes)) {
+    if (ShapedType::isDynamic(dimSize))
+      continue;
+    if (vecSize < dimSize) {
+      LDBG("input vector sizes must be greater than or equal to result sizes: "
+           << padOp << "\n");
+      return failure();
+    }
+  }
+
   if (llvm::any_of(padOp.getLow(), [](Value v) {
         std::optional<int64_t> res = getConstantIntValue(v);
         return !res.has_value() || res.value() != 0;
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -2866,6 +2866,42 @@
 
 // -----
 
+// CHECK-LABEL: func @test_masked_vectorize_dynamic_pad
+func.func @test_masked_vectorize_dynamic_pad(
+  %0 : tensor<?x?xf32>, %h0 : index, %h1 : index)
+    -> tensor<?x?xf32>
+{
+  // CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
+  // CHECK-DAG: %[[c0_2:.*]] = arith.constant 0 : index
+  // CHECK-DAG: %[[c42:.*]] = arith.constant 4.243000e+01 : f32
+  // CHECK-DAG: %[[empty:.*]] = tensor.empty() : tensor<2x4xf32>
+  // CHECK: %[[d0:.*]] = tensor.dim {{.*}} : tensor<?x?xf32>
+  // CHECK: %[[d1:.*]] = tensor.dim {{.*}} : tensor<?x?xf32>
+  // CHECK: %[[mask:.*]] = vector.create_mask %[[d0]], %[[d1]] : vector<2x4xi1>
+  // CHECK: %[[masked_read:.*]] = vector.mask %[[mask]] {
+  // CHECK-SAME:   vector.transfer_read %{{.*}}[%[[c0_2]], %[[c0_2]]], %[[c42]]
+  // CHECK-SAME:   {in_bounds = [true, true]} : tensor<?x?xf32>, vector<2x4xf32>
+  // CHECK-SAME: } : vector<2x4xi1> -> vector<2x4xf32>
+  // CHECK: vector.transfer_write %[[masked_read]], %[[empty]][%[[c0_2]], %[[c0_2]]]
+  // CHECK-SAME:   {in_bounds = [true, true]} : vector<2x4xf32>, tensor<2x4xf32>
+  %cst = arith.constant 42.43 : f32
+  %c0 = arith.constant 0 : index
+  %1 = tensor.pad %0 low[0, %c0] high[%h0, %h1] {
+    ^bb0(%hh1: index, %hh2: index):
+      tensor.yield %cst : f32
+  } : tensor<?x?xf32> to tensor<?x?xf32>
+  return %1: tensor<?x?xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["tensor.pad"]} in %arg1
+    : (!pdl.operation) -> !pdl.operation
+  transform.structured.masked_vectorize %0 vector_sizes [2, 4]
+}
+
+// -----
+
 // CHECK-LABEL: func @test_masked_pad_static_dynamic
 func.func @test_masked_pad_static_dynamic(%arg0: tensor<1x2x2x?xf32>, %low: index, %high: index, %pad_value: f32)
     -> tensor<6x?x?x?xf32> {