diff --git a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp @@ -322,9 +322,10 @@ // Mark all Standard operations legal. target.addLegalDialect(); + memref::MemRefDialect, StandardOpsDialect, + tensor::TensorDialect>(); target.addIllegalOp(); + tensor::InsertSliceOp, PadTensorOp>(); // Mark all Linalg operations illegal as long as they work on tensors. auto isLegalOperation = [&](Operation *op) { diff --git a/mlir/test/Dialect/Linalg/bufferize.mlir b/mlir/test/Dialect/Linalg/bufferize.mlir --- a/mlir/test/Dialect/Linalg/bufferize.mlir +++ b/mlir/test/Dialect/Linalg/bufferize.mlir @@ -265,3 +265,37 @@ // CHECK-SAME: : memref<4x5xf32> into memref<20xf32> // CHECK: %[[TENSOR:.*]] = memref.tensor_load %[[RESHAPE]] : memref<20xf32> // CHECK: return %[[TENSOR]] + +// ----- + +// CHECK-LABEL: func @pad_tensor_dynamic_shape( +// CHECK-SAME: %[[IN:.*]]: tensor<4x?x2x?xf32>, +// CHECK-SAME: %[[OFFSET:.*]]: index) -> tensor<4x?x?x?xf32> { +func @pad_tensor_dynamic_shape(%arg0: tensor<4x?x2x?xf32>, %arg1: index) -> tensor<4x?x?x?xf32> { + %c0 = constant 0 : index + %cst = constant 0.0 : f32 + %out = linalg.pad_tensor %arg0 low[%c0, %c0, %arg1, %c0] high[%c0, %c0, %c0, %arg1] { + ^bb0(%gen_arg1: index, %gen_arg2: index, %gen_arg3: index, %gen_arg4: index): // no predecessors + linalg.yield %cst : f32 + } : tensor<4x?x2x?xf32> to tensor<4x?x?x?xf32> + return %out : tensor<4x?x?x?xf32> +} + +// CHECK: %[[C3:.*]] = constant 3 : index +// CHECK: %[[C2:.*]] = constant 2 : index +// CHECK: %[[C1:.*]] = constant 1 : index +// CHECK: %[[CST:.*]] = constant 0.000000e+00 : f32 +// CHECK: %[[DIM1:.*]] = tensor.dim %[[IN]], %[[C1]] : tensor<4x?x2x?xf32> +// CHECK: %[[OUT_DIM2:.*]] = addi %[[OFFSET]], %[[C2]] : index +// CHECK: %[[DIM3:.*]] = tensor.dim %[[IN]], %[[C3]] : tensor<4x?x2x?xf32> +// CHECK: %[[OUT_DIM3:.*]] = addi %[[DIM3]], %[[OFFSET]] : index +// CHECK: %[[FILLED:.*]] = memref.alloc(%[[DIM1]], %[[OUT_DIM2]], %[[OUT_DIM3]]) : memref<4x?x?x?xf32> +// CHECK: linalg.fill(%[[CST]], %[[FILLED]]) : f32, memref<4x?x?x?xf32> +// CHECK: %[[IN_MEMREF:.*]] = memref.buffer_cast %[[IN]] : memref<4x?x2x?xf32> +// CHECK: %[[OUT:.*]] = memref.alloc(%[[DIM1]], %[[OUT_DIM2]], %[[OUT_DIM3]]) : memref<4x?x?x?xf32> +// CHECK: linalg.copy(%[[FILLED]], %[[OUT]]) : memref<4x?x?x?xf32>, memref<4x?x?x?xf32> +// CHECK: %[[INTERIOR:.*]] = memref.subview %[[OUT]][0, 0, %[[OFFSET]], 0] [4, %[[DIM1]], 2, %[[DIM3]]] [1, 1, 1, 1] : memref<4x?x?x?xf32> to memref<4x?x2x?xf32, #map> +// CHECK: linalg.copy(%[[IN_MEMREF]], %[[INTERIOR]]) : memref<4x?x2x?xf32>, memref<4x?x2x?xf32, #map> +// CHECK: %[[OUT_TENSOR:.*]] = memref.tensor_load %[[OUT]] : memref<4x?x?x?xf32> +// CHECK: return %[[OUT_TENSOR]] : tensor<4x?x?x?xf32> +// CHECK: }