diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" @@ -790,9 +791,12 @@ Value srcDim = rewriter.create(loc, padOp.getSource(), i); Value lowPad = toValue(mixedLowPad[i]); Value highPad = toValue(mixedHighPad[i]); - Value s1 = rewriter.create(loc, lowPad, highPad); - Value s2 = rewriter.create(loc, s1, srcDim); - dynamicSizes.push_back(s2); + AffineExpr s0, s1, s2; + bindSymbols(op->getContext(), s0, s1, s2); + AffineExpr sumExpr = s0 + s1 + s2; + Value sum = rewriter.create( + loc, sumExpr, ValueRange{srcDim, lowPad, highPad}); + dynamicSizes.push_back(sum); } // Create tensor::GenerateOp. diff --git a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt @@ -11,6 +11,7 @@ MLIRTensorTransformsIncGen LINK_LIBS PUBLIC + MLIRAffineDialect MLIRArithmeticDialect MLIRBufferizationDialect MLIRBufferizationTransforms diff --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir --- a/mlir/test/Dialect/Tensor/bufferize.mlir +++ b/mlir/test/Dialect/Tensor/bufferize.mlir @@ -547,6 +547,7 @@ // ----- +// CHECK: #[[$sum_map:.*]] = affine_map<()[s0, s1, s2] -> (s0 + s1 + s2)> // CHECK-LABEL: func @tensor.pad( // CHECK-SAME: %[[t1:.*]]: tensor, %[[l2:.*]]: index, %[[h1:.*]]: index, %[[h2:.*]]: index func.func @tensor.pad(%t1: tensor, %l2: index, %h1: index, @@ -557,10 +558,8 @@ // CHECK-DAG: %[[c5:.*]] = arith.constant 5 : index // CHECK-DAG: %[[dim0:.*]] = memref.dim %[[m1]], %[[c0]] // CHECK-DAG: %[[dim1:.*]] = memref.dim %[[m1]], %[[c1]] - // CHECK-DAG: %[[pad0:.*]] = arith.addi %[[c5]], %[[h1]] - // CHECK-DAG: %[[size0:.*]] = arith.addi %[[pad0]], %[[dim0]] - // CHECK-DAG: %[[pad1:.*]] = arith.addi %[[l2]], %[[h2]] - // CHECK-DAG: %[[size1:.*]] = arith.addi %[[pad1]], %[[dim1]] + // CHECK-DAG: %[[size0:.*]] = affine.apply #[[$sum_map]]()[%[[dim0]], %[[c5]], %[[h1]]] + // CHECK-DAG: %[[size1:.*]] = affine.apply #[[$sum_map]]()[%[[dim1]], %[[l2]], %[[h2]]] // CHECK: %[[alloc:.*]] = memref.alloc(%[[size0]], %[[size1]]) {{.*}} : memref // CHECK: scf.parallel ({{.*}}) = (%[[c0]], %[[c0]]) to (%[[size0]], %[[size1]]) step (%[[c1]], %[[c1]]) { // CHECK: memref.store diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -5048,6 +5048,7 @@ ], includes = ["include"], deps = [ + ":AffineDialect", ":ArithmeticDialect", ":BufferizationDialect", ":BufferizationTransforms",