diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp @@ -832,27 +832,37 @@ AnalysisState &state, const BufferizationAliasInfo &aliasInfo) { const BufferizationOptions &options = state.getOptions(); - Operation *inconsistentOp = nullptr; - WalkResult walkResult = op->walk([&](Operation *op) { - if (auto bufferizableOp = options.dynCastBufferizableOp(op)) - for (OpOperand &opOperand : op->getOpOperands()) - if (opOperand.get().getType().isa()) { - if (wouldCreateReadAfterWriteInterference( - opOperand, domInfo, state, aliasInfo, - /*checkConsistencyOnly=*/true)) { - // This error can happen if certain "mustBufferizeInPlace" interface - // methods are implemented incorrectly, such that the IR already has - // a RaW conflict before making any bufferization decisions. - inconsistentOp = op; - return WalkResult::interrupt(); - } + + WalkResult walkResult = op->walk([&](BufferizableOpInterface op) { + // Skip ops that are not in the filter. + if (!options.isOpAllowed(op.getOperation())) + return WalkResult::advance(); + + // Input IR may not contain any ToMemrefOps. These are not supported because + // the analysis cannot follow the data flow through memrefs. + if (isa(op.getOperation())) { + op->emitError("to_memref ops not supported during One-Shot Analysis"); + return WalkResult::interrupt(); + } + + for (OpOperand &opOperand : op->getOpOperands()) { + if (opOperand.get().getType().isa()) { + if (wouldCreateReadAfterWriteInterference( + opOperand, domInfo, state, aliasInfo, + /*checkConsistencyOnly=*/true)) { + // This error can happen if certain "mustBufferizeInPlace" interface + // methods are implemented incorrectly, such that the IR already has + // a RaW conflict before making any bufferization decisions. + op->emitError("input IR has RaW conflict"); + return WalkResult::interrupt(); } + } + } + return WalkResult::advance(); }); - if (walkResult.wasInterrupted()) - return inconsistentOp->emitError("input IR has RaW conflict"); - return success(); + return success(!walkResult.wasInterrupted()); } /// Annotate the IR with the result of the analysis. For testing/debugging only. diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-analysis.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-analysis.mlir --- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-analysis.mlir +++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-analysis.mlir @@ -1074,28 +1074,6 @@ // ----- -// CHECK-LABEL: func @to_memref_op_is_reading -func.func @to_memref_op_is_reading(%t1: tensor {bufferization.writable = true}, - %idx1: index, %idx2: index, %idx3: index, - %v1: vector<5xf32>) - -> (vector<5xf32>, vector<5xf32>) { - // Write + read to/from tensor. - // CHECK: vector.transfer_write - // CHECK-SAME: {__inplace_operands_attr__ = ["none", "false", "none"] - %1 = vector.transfer_write %v1, %t1[%idx2] : vector<5xf32>, tensor - %cst = arith.constant 0.0 : f32 - %r1 = vector.transfer_read %1[%idx3], %cst : tensor, vector<5xf32> - - // Write + read to/from same memref. - %0 = bufferization.to_memref %t1 : memref - vector.transfer_write %v1, %0[%idx1] : vector<5xf32>, memref - %r2 = vector.transfer_read %0[%idx3], %cst : memref, vector<5xf32> - - return %r1, %r2 : vector<5xf32>, vector<5xf32> -} - -// ----- - // CHECK-LABEL: func @inner_func func.func @inner_func(%t: tensor) -> tensor { // CHECK: return diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir --- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir +++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir @@ -249,7 +249,7 @@ // read further down. This will likely have to change with partial // bufferization. - // expected-error @+1 {{input IR has RaW conflict}} + // expected-error @+1 {{to_memref ops not supported during One-Shot Analysis}} %0 = bufferization.to_memref %t1 : memref // Read from both. @@ -289,7 +289,7 @@ // ----- func.func @destination_passing_style_dominance_test_1(%cst : f32, %idx : index, - %idx2 : index) -> f32 { + %idx2 : index) -> f32 { %0 = scf.execute_region -> tensor { %1 = bufferization.alloc_tensor(%idx) : tensor // expected-error @+1 {{operand #0 of ReturnLike op does not satisfy destination passing style}} @@ -303,7 +303,7 @@ // ----- func.func @destination_passing_style_dominance_test_2(%cst : f32, %idx : index, - %idx2 : index) -> f32 { + %idx2 : index) -> f32 { %1 = bufferization.alloc_tensor(%idx) : tensor %0 = scf.execute_region -> tensor { diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/concatenate.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/concatenate.mlir --- a/mlir/test/Integration/Dialect/SparseTensor/CPU/concatenate.mlir +++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/concatenate.mlir @@ -166,8 +166,7 @@ %du = arith.constant -1.0 : f64 %c = sparse_tensor.convert %A : tensor<9x4xf64, #MAT_C_C> to tensor<9x4xf64> - %m = bufferization.to_memref %c : memref<9x4xf64> - %v = vector.transfer_read %m[%c0, %c0], %du: memref<9x4xf64>, vector<9x4xf64> + %v = vector.transfer_read %c[%c0, %c0], %du: tensor<9x4xf64>, vector<9x4xf64> vector.print %v : vector<9x4xf64> %1 = sparse_tensor.values %A : tensor<9x4xf64, #MAT_C_C> to memref @@ -182,8 +181,7 @@ %du = arith.constant -1.0 : f64 %c = sparse_tensor.convert %A : tensor<9x4xf64, #MAT_C_C_P> to tensor<9x4xf64> - %m = bufferization.to_memref %c : memref<9x4xf64> - %v = vector.transfer_read %m[%c0, %c0], %du: memref<9x4xf64>, vector<9x4xf64> + %v = vector.transfer_read %c[%c0, %c0], %du: tensor<9x4xf64>, vector<9x4xf64> vector.print %v : vector<9x4xf64> %1 = sparse_tensor.values %A : tensor<9x4xf64, #MAT_C_C_P> to memref @@ -197,8 +195,7 @@ %c0 = arith.constant 0 : index %du = arith.constant -1.0 : f64 - %m = bufferization.to_memref %A : memref<9x4xf64> - %v = vector.transfer_read %m[%c0, %c0], %du: memref<9x4xf64>, vector<9x4xf64> + %v = vector.transfer_read %A[%c0, %c0], %du: tensor<9x4xf64>, vector<9x4xf64> vector.print %v : vector<9x4xf64> return @@ -209,8 +206,7 @@ %du = arith.constant -1.0 : f64 %c = sparse_tensor.convert %A : tensor<4x9xf64, #MAT_C_C> to tensor<4x9xf64> - %m = bufferization.to_memref %c : memref<4x9xf64> - %v = vector.transfer_read %m[%c0, %c0], %du: memref<4x9xf64>, vector<4x9xf64> + %v = vector.transfer_read %c[%c0, %c0], %du: tensor<4x9xf64>, vector<4x9xf64> vector.print %v : vector<4x9xf64> %1 = sparse_tensor.values %A : tensor<4x9xf64, #MAT_C_C> to memref @@ -225,8 +221,7 @@ %du = arith.constant -1.0 : f64 %c = sparse_tensor.convert %A : tensor to tensor - %m = bufferization.to_memref %c : memref - %v = vector.transfer_read %m[%c0, %c0], %du: memref, vector<4x9xf64> + %v = vector.transfer_read %c[%c0, %c0], %du: tensor, vector<4x9xf64> vector.print %v : vector<4x9xf64> %1 = sparse_tensor.values %A : tensor to memref @@ -241,8 +236,7 @@ %du = arith.constant -1.0 : f64 %c = sparse_tensor.convert %A : tensor<4x9xf64, #MAT_C_C_P> to tensor<4x9xf64> - %m = bufferization.to_memref %c : memref<4x9xf64> - %v = vector.transfer_read %m[%c0, %c0], %du: memref<4x9xf64>, vector<4x9xf64> + %v = vector.transfer_read %c[%c0, %c0], %du: tensor<4x9xf64>, vector<4x9xf64> vector.print %v : vector<4x9xf64> %1 = sparse_tensor.values %A : tensor<4x9xf64, #MAT_C_C_P> to memref @@ -256,8 +250,7 @@ %c0 = arith.constant 0 : index %du = arith.constant -1.0 : f64 - %m = bufferization.to_memref %A : memref<4x9xf64> - %v = vector.transfer_read %m[%c0, %c0], %du: memref<4x9xf64>, vector<4x9xf64> + %v = vector.transfer_read %A[%c0, %c0], %du: tensor<4x9xf64>, vector<4x9xf64> vector.print %v : vector<4x9xf64> return