diff --git a/mlir/lib/Dialect/Bufferization/Transforms/AllocTensorElimination.cpp b/mlir/lib/Dialect/Bufferization/Transforms/AllocTensorElimination.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/AllocTensorElimination.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/AllocTensorElimination.cpp @@ -140,6 +140,15 @@ return WalkResult::skip(); Value allocTensor = maybeAllocTensor.front(); + // Replace only if the types match. + // TODO: This could be extended to support IR such as: + // %0 = bufferization.alloc_tensor : tensor<128xf32> + // %1 = "some_op"(%0) : (tensor<128xf32>) -> (tensor<128xf32>) + // %2 = tensor.expand_shape %1 ... + // %3 = tensor.insert_slice %2 into ... + if (allocTensor.getType() != operand.get().getType()) + return WalkResult::skip(); + // Find a suitable insertion point. Operation *insertionPoint = findValidInsertionPoint(allocTensor.getDefiningOp(), neededValues); diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-alloc-tensor-elimination.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-alloc-tensor-elimination.mlir --- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-alloc-tensor-elimination.mlir +++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-alloc-tensor-elimination.mlir @@ -94,7 +94,7 @@ // CHECK: func @insertion_point_outside_loop( // CHECK-SAME: %[[t:.*]]: memref, %[[sz:.*]]: index, %[[idx:.*]]: index) func.func @insertion_point_outside_loop(%t : tensor, %sz : index, - %idx : index) -> (tensor) { + %idx : index) -> (tensor) { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c5 = arith.constant 5 : index @@ -118,3 +118,21 @@ return %r : tensor } + +// ----- + +// AllocTensorElimination does not currently apply to chains where the type is +// changing. This test just ensures that we do not crash or generate IR that +// does not verify.
+ +// CHECK-LABEL: func @shape_mismatch +func.func @shape_mismatch(%t: tensor<5x6x128xf32>) -> tensor<5x6x128xf32> { + %cst = arith.constant 8.0 : f32 + %0 = bufferization.alloc_tensor() : tensor<128xf32> + %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<128xf32>) -> tensor<128xf32> + %2 = tensor.expand_shape %1 [[0, 1, 2]] + : tensor<128xf32> into tensor<1x1x128xf32> + %3 = tensor.insert_slice %2 into %t[2, 3, 0][1, 1, 128][1, 1, 1] + : tensor<1x1x128xf32> into tensor<5x6x128xf32> + return %3 : tensor<5x6x128xf32> +}