diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1423,8 +1423,10 @@
     return rewriter.notifyMatchFailure(
         padOp, "result tensor shape must match input vector sizes");
   }
-  if (llvm::any_of(padOp.getStaticLow(),
-                   [](int64_t val) { return val != 0; })) {
+  if (llvm::any_of(padOp.getLow(), [](Value v) {
+        std::optional<int64_t> res = getConstantIntValue(v);
+        return !res.has_value() || res.value() != 0;
+      })) {
     LDBG("low pad must all be zero: " << padOp << "\n");
     return rewriter.notifyMatchFailure(padOp, "low pad must all be zero");
   }
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -2836,19 +2836,21 @@
     -> tensor<2x4xf32>
 {
   // CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
+  // CHECK-DAG: %[[c0_2:.*]] = arith.constant 0 : index
   // CHECK-DAG: %[[c42:.*]] = arith.constant 4.243000e+01 : f32
   // CHECK-DAG: %[[empty:.*]] = tensor.empty() : tensor<2x4xf32>
   // CHECK: %[[d0:.*]] = tensor.dim {{.*}} : tensor<?x?xf32>
   // CHECK: %[[d1:.*]] = tensor.dim {{.*}} : tensor<?x?xf32>
   // CHECK: %[[mask:.*]] = vector.create_mask %[[d0]], %[[d1]] : vector<2x4xi1>
   // CHECK: %[[masked_read:.*]] = vector.mask %[[mask]] {
-  // CHECK-SAME: vector.transfer_read %{{.*}}[%[[c0]], %[[c0]]], %[[c42]]
+  // CHECK-SAME: vector.transfer_read %{{.*}}[%[[c0_2]], %[[c0_2]]], %[[c42]]
   // CHECK-SAME: {in_bounds = [true, true]} : tensor<?x?xf32>, vector<2x4xf32>
   // CHECK-SAME: } : vector<2x4xi1> -> vector<2x4xf32>
-  // CHECK: vector.transfer_write %[[masked_read]], %[[empty]][%[[c0]], %[[c0]]]
+  // CHECK: vector.transfer_write %[[masked_read]], %[[empty]][%[[c0_2]], %[[c0_2]]]
   // CHECK-SAME: {in_bounds = [true, true]} : vector<2x4xf32>, tensor<2x4xf32>
   %cst = arith.constant 42.43 : f32
-  %1 = tensor.pad %0 low[0, 0] high[%h0, %h1] {
+  %c0 = arith.constant 0 : index
+  %1 = tensor.pad %0 low[0, %c0] high[%h0, %h1] {
   ^bb0(%hh1: index, %hh2: index):
     tensor.yield %cst : f32
   } : tensor<?x?xf32> to tensor<2x4xf32>
@@ -2864,6 +2866,27 @@
 
 // -----
 
+// CHECK-LABEL: func @test_masked_pad_static_dynamic
+func.func @test_masked_pad_static_dynamic(%arg0: tensor<1x2x2x?xf32>, %low: index, %high: index,
+                                          %pad_value: f32) -> tensor<6x?x?x?xf32> {
+  // CHECK: tensor.pad
+  %0 = tensor.pad %arg0 low[2, %low, 3, 3] high[3, 3, %high, 2] {
+    ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index):
+      tensor.yield %pad_value : f32
+  } : tensor<1x2x2x?xf32> to tensor<6x?x?x?xf32>
+  return %0 : tensor<6x?x?x?xf32>
+}
+
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["tensor.pad"]} in %arg1 : (!pdl.operation) -> !pdl.operation
+  %1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation
+  %2 = transform.structured.vectorize %1 { vectorize_padding }
+}
+
+// -----
+
 func.func @vectorize_dynamic_matmul(%A: memref<?x?xf32>, %B: memref<?x?xf32>, %C: memref<?x?xf32>) {
   linalg.matmul ins(%A, %B: memref<?x?xf32>, memref<?x?xf32>)
                 outs(%C: memref<?x?xf32>)