diff --git a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp @@ -322,9 +322,10 @@ // Mark all Standard operations legal. target.addLegalDialect(); + memref::MemRefDialect, StandardOpsDialect, + tensor::TensorDialect>(); target.addIllegalOp(); + tensor::InsertSliceOp, PadTensorOp>(); // Mark all Linalg operations illegal as long as they work on tensors. auto isLegalOperation = [&](Operation *op) {
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/test-padtensor-dynamic.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/test-padtensor-dynamic.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/test-padtensor-dynamic.mlir
@@ -0,0 +1,38 @@
+// RUN: mlir-opt %s -linalg-bufferize -std-bufferize \
+// RUN: -tensor-constant-bufferize -tensor-bufferize -func-bufferize \
+// RUN: -finalizing-bufferize \
+// RUN: -convert-linalg-to-loops -convert-scf-to-std -convert-linalg-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-cpu-runner -e main -entry-point-result=void \
+// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
+// RUN: | FileCheck %s
+
+func @main() {
+  %const = constant dense<[[[1.0, 2.0, 3.0], [2.0, 3.0, 4.0]]]> : tensor<1x2x3xf32>
+  %dynamic = tensor.cast %const: tensor<1x2x3xf32> to tensor<1x?x3xf32>
+  %offset = constant 2 : index
+  %padded = call @pad_tensor(%dynamic, %offset) : (tensor<1x?x3xf32>, index) -> (tensor<?x?x?xf32>)
+  %unranked = tensor.cast %padded: tensor<?x?x?xf32> to tensor<*xf32>
+  call @print_memref_f32(%unranked) : (tensor<*xf32>) -> ()
+
+  // CHECK: Unranked Memref base@ = {{0x[0-9a-f]*}}
+  // CHECK-SAME: rank = 3 offset = 0 sizes = [1, 4, 5] strides = [20, 5, 1] data =
+  // CHECK-NEXT{LITERAL}: [[[2.3, 2.3, 2.3, 2.3, 2.3],
+  // CHECK-NEXT: [2.3, 2.3, 2.3, 2.3, 2.3],
+  // CHECK-NEXT: [1, 2, 3, 2.3, 2.3],
+  // CHECK-NEXT: [2, 3, 4, 2.3, 2.3]]]
+
+  return
+}
+
+func @pad_tensor(%arg0: tensor<1x?x3xf32>, %arg1: index) -> tensor<?x?x?xf32> {
+  %cst = constant 2.3 : f32
+  %c0 = constant 0 : index
+  %out = linalg.pad_tensor %arg0 low[%c0, %arg1, %c0] high[%c0, %c0, %arg1]  {
+  ^bb0(%gen_arg1: index, %gen_arg2: index, %gen_arg3: index):  // no predecessors
+    linalg.yield %cst : f32
+  } : tensor<1x?x3xf32> to tensor<1x?x?xf32>
+  %dynamic = tensor.cast %out: tensor<1x?x?xf32> to tensor<?x?x?xf32>
+  return %dynamic: tensor<?x?x?xf32>
+}
+
+func private @print_memref_f32(%ptr : tensor<*xf32>)