diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h @@ -408,6 +408,10 @@ /// Specifies whether unknown/non-bufferizable/ops not included in the /// OpFilter of BufferizationOptions should be followed. bool followUnknownOps = false; + + /// Specifies whether OpOperands with a different type that are not the result + /// of a CastOpInterface op should be followed. + bool followSameTypeOrCastsOnly = false; }; /// AnalysisState provides a variety of helper functions for dealing with diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp --- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp @@ -580,6 +580,16 @@ continue; } + + if (config.followSameTypeOrCastsOnly && + a.opOperand->get().getType() != value.getType() && + !opResult.getDefiningOp<CastOpInterface>()) { + // Stop iterating if `followSameTypeOrCastsOnly` is set but the alias + // has a different type and the op is not a cast. + if (config.alwaysIncludeLeaves) + result.insert(value); + continue; + } + workingSet.insert(a.opOperand->get()); } } diff --git a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp @@ -135,6 +135,14 @@ TraversalConfig config; config.followEquivalentOnly = true; config.alwaysIncludeLeaves = false; + // Replace only if the types match or are static <-> dynamic casts. We do + // not support slices or reshapes.
+ // TODO: This could be extended to support IR such as: + // %0 = tensor.empty() : tensor<128xf32> + // %1 = "some_op"(%0) : (tensor<128xf32>) -> (tensor<128xf32>) + // %2 = tensor.expand_shape %1 ... + // %3 = tensor.insert_slice %2 into ... + config.followSameTypeOrCastsOnly = true; SetVector<Value> emptyTensors = state.findValueInReverseUseDefChain( operand.get(), /*condition=*/ [&](Value val) { return val.getDefiningOp<tensor::EmptyOp>(); }, @@ -143,15 +151,6 @@ for (Value v : emptyTensors) { Operation *emptyTensorOp = v.getDefiningOp(); - // Replace only if the types match. We do not support slices or casts. - // TODO: This could be extended to support IR such as: - // %0 = tensor.empty() : tensor<128xf32> - // %1 = "some_op"(%0) : (tensor<128xf32>) -> (tensor<128xf32>) - // %2 = tensor.expand_shape %1 ... - // %3 = tensor.insert_slice %2 into ... - if (v.getType() != operand.get().getType()) - continue; - // Find a suitable insertion point. If no suitable insertion point for // the replacement can be found, skip this replacement. Operation *insertionPoint = @@ -164,7 +163,11 @@ rewriteFunc(rewriter, emptyTensorOp->getLoc(), operand); if (!replacement) continue; - + if (replacement.getType() != v.getType()) { + rewriter.setInsertionPointAfterValue(replacement); + replacement = rewriter.create<tensor::CastOp>(v.getLoc(), v.getType(), + replacement); + } // Replace the tensor::EmptyOp. rewriter.replaceOp(emptyTensorOp, replacement); state.resetCache(); diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir --- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir +++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir @@ -123,8 +123,8 @@ // ----- // EmptyTensorElimination does currently not apply to chains where the type is -// changing.
This test just ensures that we do not crash or generate IR that -// does not verify. +// changing. (Casts are supported.) This test just ensures that we do not crash +// or generate IR that does not verify. // CHECK-LABEL: func @shape_mismatch func.func @shape_mismatch(%t: tensor<5x6x128xf32>) -> tensor<5x6x128xf32> { @@ -140,6 +140,24 @@ // ----- +// CHECK-LABEL: func @cast( +// CHECK-SAME: %[[t:.*]]: memref<256xf32, +// CHECK: %[[sv:.*]] = memref.subview %[[t]] +// CHECK: linalg.fill {{.*}} outs(%[[sv]] +// CHECK: return %[[t]] +func.func @cast(%t: tensor<256xf32>) -> tensor<256xf32> { + %cst = arith.constant 8.0 : f32 + %c128 = arith.constant 128 : index + %0 = tensor.empty(%c128) : tensor<?xf32> + %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<?xf32>) -> tensor<?xf32> + %2 = tensor.cast %1 : tensor<?xf32> to tensor<128xf32> + %3 = tensor.insert_slice %2 into %t[2][128][1] + : tensor<128xf32> into tensor<256xf32> + return %3 : tensor<256xf32> +} + +// ----- + // CHECK: func @parallel_insert_slice( // CHECK-SAME: %[[FUNC_ARG:[0-9a-zA-Z]*]]: memref // CHECK-SAME: %[[sz:[0-9a-zA-Z]*]]: index