diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -271,7 +271,7 @@
 
 /// Helper function for loop bufferization. Return the indices of all
 /// bbArg/yielded value pairs whose buffer relation is "Equivalent".
-DenseSet<int64_t> getEquivalentBuffers(ValueRange bbArgs,
+DenseSet<int64_t> getEquivalentBuffers(Block::BlockArgListType bbArgs,
                                        ValueRange yieldedValues,
                                        const AnalysisState &state) {
   DenseSet<int64_t> result;
@@ -403,6 +403,18 @@
   });
 }
 
+/// Helper function for loop bufferization. Given a list of bbArgs of the new
+/// (bufferized) loop op, wrap the bufferized tensor args (now memrefs) into
+/// ToTensorOps, so that the block body can be moved over to the new op.
+SmallVector<Value>
+getBbArgReplacements(RewriterBase &rewriter, Block::BlockArgListType bbArgs,
+                     const DenseSet<int64_t> &tensorIndices) {
+  return convertTensorValues(
+      bbArgs, tensorIndices, [&](Value val, int64_t index) {
+        return rewriter.create<bufferization::ToTensorOp>(val.getLoc(), val);
+      });
+}
+
 /// Bufferization of scf.for. Replace with a new scf.for that operates on
 /// memrefs.
 struct ForOpInterface
@@ -486,10 +498,8 @@
     // Set up new iter_args. The loop body uses tensors, so wrap the (memref)
     // iter_args of the new loop in ToTensorOps.
     rewriter.setInsertionPointToStart(loopBody);
-    SmallVector<Value> iterArgs = convertTensorValues(
-        newForOp.getRegionIterArgs(), indices, [&](Value val, int64_t index) {
-          return rewriter.create<bufferization::ToTensorOp>(val.getLoc(), val);
-        });
+    SmallVector<Value> iterArgs =
+        getBbArgReplacements(rewriter, newForOp.getRegionIterArgs(), indices);
     iterArgs.insert(iterArgs.begin(), newForOp.getInductionVar());
 
     // Erase terminator if present.
@@ -546,6 +556,187 @@
   }
 };
 
+/// Bufferization of scf.while. Replace with a new scf.while that operates on
+/// memrefs.
+struct WhileOpInterface
+    : public BufferizableOpInterface::ExternalModel<WhileOpInterface,
+                                                    scf::WhileOp> {
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+                              const AnalysisState &state) const {
+    // Tensor iter_args of scf::WhileOps are always considered as a read.
+    return true;
+  }
+
+  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+                               const AnalysisState &state) const {
+    // Tensor iter_args of scf::WhileOps are always considered as a write.
+    return true;
+  }
+
+  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
+                                            const AnalysisState &state) const {
+    auto whileOp = cast<scf::WhileOp>(op);
+    return {whileOp->getResult(opOperand.getOperandNumber())};
+  }
+
+  BufferRelation bufferRelation(Operation *op, OpResult opResult,
+                                const AnalysisState &state) const {
+    // WhileOp results are equivalent to their corresponding init_args if the
+    // corresponding iter_args and yielded values are equivalent (for both the
+    // "before" and the "after" block).
+    unsigned int resultNumber = opResult.getResultNumber();
+    auto whileOp = cast<scf::WhileOp>(op);
+
+    auto conditionOp = whileOp.getConditionOp();
+    BlockArgument conditionBbArg = whileOp.getBeforeArguments()[resultNumber];
+    Value conditionOperand = conditionOp.getArgs()[resultNumber];
+    bool equivCondition =
+        state.areEquivalentBufferizedValues(conditionBbArg, conditionOperand);
+
+    auto yieldOp = whileOp.getYieldOp();
+    BlockArgument bodyBbArg = whileOp.getAfterArguments()[resultNumber];
+    Value yieldOperand = yieldOp.getOperand(resultNumber);
+    bool equivYield =
+        state.areEquivalentBufferizedValues(bodyBbArg, yieldOperand);
+
+    return equivCondition && equivYield ? BufferRelation::Equivalent
+                                        : BufferRelation::None;
+  }
+
+  bool isWritable(Operation *op, Value value,
+                  const AnalysisState &state) const {
+    // Interestingly, scf::WhileOp's bbArgs can **always** be viewed
+    // inplace from the perspective of ops nested under:
+    //   1. Either the matching iter operand is not bufferized inplace and an
+    //      alloc + optional copy makes the bbArg itself inplaceable.
+    //   2. Or the matching iter operand is bufferized inplace and bbArg just
+    //      bufferizes to that too.
+    return true;
+  }
+
+  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+                          BufferizationState &state) const {
+    auto whileOp = cast<scf::WhileOp>(op);
+
+    assert(whileOp.getBefore().getBlocks().size() == 1 &&
+           "regions with multiple blocks not supported");
+    Block *beforeBody = &whileOp.getBefore().front();
+    assert(whileOp.getAfter().getBlocks().size() == 1 &&
+           "regions with multiple blocks not supported");
+    Block *afterBody = &whileOp.getAfter().front();
+
+    // Indices of all iter_args that have tensor type. These are the ones that
+    // are bufferized.
+    DenseSet<int64_t> indices = getTensorIndices(whileOp.getInits());
+    // For every yielded value, is the value equivalent to its corresponding
+    // bbArg?
+    DenseSet<int64_t> equivalentYieldsBefore = getEquivalentBuffers(
+        whileOp.getBeforeArguments(), whileOp.getConditionOp().getArgs(),
+        state.getAnalysisState());
+    DenseSet<int64_t> equivalentYieldsAfter = getEquivalentBuffers(
+        whileOp.getAfterArguments(), whileOp.getYieldOp().getResults(),
+        state.getAnalysisState());
+
+    // The new memref init_args of the loop.
+    SmallVector<Value> initArgs =
+        getBuffers(rewriter, whileOp->getOpOperands(), state);
+    if (initArgs.size() != indices.size())
+      return failure();
+
+    // Construct a new scf.while op with memref instead of tensor values.
+    ValueRange argsRange(initArgs);
+    TypeRange argsTypes(argsRange);
+    auto newWhileOp =
+        rewriter.create<scf::WhileOp>(whileOp.getLoc(), argsTypes, initArgs);
+    // Add before/after regions to the new op.
+    SmallVector<Location> bbArgLocs(initArgs.size(), whileOp.getLoc());
+    Block *newBeforeBody = &newWhileOp.getBefore().emplaceBlock();
+    newWhileOp.getBefore().addArguments(argsTypes, bbArgLocs);
+    Block *newAfterBody = &newWhileOp.getAfter().emplaceBlock();
+    newWhileOp.getAfter().addArguments(argsTypes, bbArgLocs);
+
+    // Set up new iter_args and move the loop condition block to the new op.
+    // The old block uses tensors, so wrap the (memref) bbArgs of the new block
+    // in ToTensorOps.
+    rewriter.setInsertionPointToStart(newBeforeBody);
+    SmallVector<Value> newBeforeArgs = getBbArgReplacements(
+        rewriter, newWhileOp.getBeforeArguments(), indices);
+    rewriter.mergeBlocks(beforeBody, newBeforeBody, newBeforeArgs);
+
+    // Update the scf.condition of the new loop.
+    auto newConditionOp = newWhileOp.getConditionOp();
+    rewriter.setInsertionPoint(newConditionOp);
+    SmallVector<Value> newConditionArgs =
+        getYieldedValues(rewriter, newConditionOp.getArgs(), argsTypes, indices,
+                         equivalentYieldsBefore, state);
+    newConditionOp.getArgsMutable().assign(newConditionArgs);
+
+    // Set up new iter_args and move the loop body block to the new op.
+    // The old block uses tensors, so wrap the (memref) bbArgs of the new block
+    // in ToTensorOps.
+    rewriter.setInsertionPointToStart(newAfterBody);
+    SmallVector<Value> newAfterArgs =
+        getBbArgReplacements(rewriter, newWhileOp.getAfterArguments(), indices);
+    rewriter.mergeBlocks(afterBody, newAfterBody, newAfterArgs);
+
+    // Update the scf.yield of the new loop.
+    auto newYieldOp = newWhileOp.getYieldOp();
+    rewriter.setInsertionPoint(newYieldOp);
+    SmallVector<Value> newYieldValues =
+        getYieldedValues(rewriter, newYieldOp.getResults(), argsTypes, indices,
+                         equivalentYieldsAfter, state);
+    newYieldOp.getResultsMutable().assign(newYieldValues);
+
+    // Replace loop results.
+    replaceOpWithBufferizedValues(rewriter, op, newWhileOp->getResults());
+
+    return success();
+  }
+
+  /// Assert that yielded values of an scf.while op are equivalent to their
+  /// corresponding bbArgs. In that case, the buffer relations of the
+  /// corresponding OpResults are "Equivalent".
+  ///
+  /// If this is not the case, allocs+copies are inserted and yielded from
+  /// the loop. This could be a performance problem, so it must be explicitly
+  /// activated with `allow-return-allocs`.
+  ///
+  /// Note: In contrast to scf::ForOp, scf::WhileOp has two regions and the
+  /// equivalence condition must be checked for both.
+  LogicalResult verifyAnalysis(Operation *op,
+                               const AnalysisState &state) const {
+    auto whileOp = cast<scf::WhileOp>(op);
+    const auto &options =
+        static_cast<const OneShotBufferizationOptions &>(state.getOptions());
+    if (options.allowReturnAllocs)
+      return success();
+
+    auto conditionOp = whileOp.getConditionOp();
+    for (const auto &it : llvm::enumerate(conditionOp.getArgs())) {
+      if (!it.value().getType().isa<TensorType>())
+        continue;
+      if (!state.areEquivalentBufferizedValues(
+              it.value(), conditionOp->getBlock()->getArgument(it.index())))
+        return conditionOp->emitError()
+               << "Condition arg #" << it.index()
+               << " is not equivalent to the corresponding iter bbArg";
+    }
+
+    auto yieldOp = whileOp.getYieldOp();
+    for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
+      if (!it.value().getType().isa<TensorType>())
+        continue;
+      if (!state.areEquivalentBufferizedValues(
+              it.value(), yieldOp->getBlock()->getArgument(it.index())))
+        return yieldOp->emitError()
+               << "Yield operand #" << it.index()
+               << " is not equivalent to the corresponding iter bbArg";
+    }
+
+    return success();
+  }
+};
+
 /// Bufferization of scf.yield. Bufferized as part of their enclosing ops, so
 /// this is for analysis only.
 struct YieldOpInterface
@@ -581,7 +772,7 @@
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           BufferizationState &state) const {
     auto yieldOp = cast<scf::YieldOp>(op);
-    if (!isa<scf::ExecuteRegionOp, scf::IfOp, scf::ForOp>(
+    if (!isa<scf::ExecuteRegionOp, scf::IfOp, scf::ForOp, scf::WhileOp>(
             yieldOp->getParentOp()))
       return yieldOp->emitError("unsupported scf::YieldOp parent");
     return success();
@@ -598,6 +789,7 @@
     ExecuteRegionOp::attachInterface<ExecuteRegionOpInterface>(*ctx);
     ForOp::attachInterface<ForOpInterface>(*ctx);
     IfOp::attachInterface<IfOpInterface>(*ctx);
+    WhileOp::attachInterface<WhileOpInterface>(*ctx);
     YieldOp::attachInterface<YieldOpInterface>(*ctx);
   });
 }
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir
@@ -110,6 +110,54 @@
 
 // -----
 
+func.func @scf_while_non_equiv_condition(%arg0: tensor<5xi1>,
+                                         %arg1: tensor<5xi1>,
+                                         %idx: index) -> (i1, i1)
+{
+  %r0, %r1 = scf.while (%w0 = %arg0, %w1 = %arg1)
+      : (tensor<5xi1>, tensor<5xi1>) -> (tensor<5xi1>, tensor<5xi1>) {
+    %condition = tensor.extract %w0[%idx] : tensor<5xi1>
+    // expected-error @+1 {{Condition arg #0 is not equivalent to the corresponding iter bbArg}}
+    scf.condition(%condition) %w1, %w0 : tensor<5xi1>, tensor<5xi1>
+  } do {
+  ^bb0(%b0: tensor<5xi1>, %b1: tensor<5xi1>):
+    %pos = "dummy.some_op"() : () -> (index)
+    %val = "dummy.another_op"() : () -> (i1)
+    %1 = tensor.insert %val into %b0[%pos] : tensor<5xi1>
+    scf.yield %1, %b1 : tensor<5xi1>, tensor<5xi1>
+  }
+
+  %v0 = tensor.extract %r0[%idx] : tensor<5xi1>
+  %v1 = tensor.extract %r1[%idx] : tensor<5xi1>
+  return %v0, %v1 : i1, i1
+}
+
+// -----
+
+func.func @scf_while_non_equiv_yield(%arg0: tensor<5xi1>,
+                                     %arg1: tensor<5xi1>,
+                                     %idx: index) -> (i1, i1)
+{
+  %r0, %r1 = scf.while (%w0 = %arg0, %w1 = %arg1)
+      : (tensor<5xi1>, tensor<5xi1>) -> (tensor<5xi1>, tensor<5xi1>) {
+    %condition = tensor.extract %w0[%idx] : tensor<5xi1>
+    scf.condition(%condition) %w0, %w1 : tensor<5xi1>, tensor<5xi1>
+  } do {
+  ^bb0(%b0: tensor<5xi1>, %b1: tensor<5xi1>):
+    %pos = "dummy.some_op"() : () -> (index)
+    %val = "dummy.another_op"() : () -> (i1)
+    %1 = tensor.insert %val into %b0[%pos] : tensor<5xi1>
+    // expected-error @+1 {{Yield operand #0 is not equivalent to the corresponding iter bbArg}}
+    scf.yield %b1, %1 : tensor<5xi1>, tensor<5xi1>
+  }
+
+  %v0 = tensor.extract %r0[%idx] : tensor<5xi1>
+  %v1 = tensor.extract %r1[%idx] : tensor<5xi1>
+  return %v0, %v1 : i1, i1
+}
+
+// -----
+
 func.func private @fun_with_side_effects(%A: tensor<?xf32> {bufferization.writable = true})
 
 func.func @foo(%A: tensor<?xf32> {bufferization.writable = true}) -> (tensor<?xf32>) {
diff --git a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
--- a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
@@ -1,12 +1,12 @@
-// RUN: mlir-opt %s -one-shot-bufferize="allow-return-allocs bufferize-function-boundaries" -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -allow-unregistered-dialect -one-shot-bufferize="allow-return-allocs bufferize-function-boundaries" -split-input-file | FileCheck %s
 
 // Run fuzzer with different seeds.
-// RUN: mlir-opt %s -one-shot-bufferize="allow-return-allocs test-analysis-only analysis-fuzzer-seed=23 bufferize-function-boundaries" -split-input-file -o /dev/null
-// RUN: mlir-opt %s -one-shot-bufferize="allow-return-allocs test-analysis-only analysis-fuzzer-seed=59 bufferize-function-boundaries" -split-input-file -o /dev/null
-// RUN: mlir-opt %s -one-shot-bufferize="allow-return-allocs test-analysis-only analysis-fuzzer-seed=91 bufferize-function-boundaries" -split-input-file -o /dev/null
+// RUN: mlir-opt %s -allow-unregistered-dialect -one-shot-bufferize="allow-return-allocs test-analysis-only analysis-fuzzer-seed=23 bufferize-function-boundaries" -split-input-file -o /dev/null
+// RUN: mlir-opt %s -allow-unregistered-dialect -one-shot-bufferize="allow-return-allocs test-analysis-only analysis-fuzzer-seed=59 bufferize-function-boundaries" -split-input-file -o /dev/null
+// RUN: mlir-opt %s -allow-unregistered-dialect -one-shot-bufferize="allow-return-allocs test-analysis-only analysis-fuzzer-seed=91 bufferize-function-boundaries" -split-input-file -o /dev/null
 
 // Test bufferization using memref types that have no layout map.
-// RUN: mlir-opt %s -one-shot-bufferize="allow-return-allocs fully-dynamic-layout-maps=0 bufferize-function-boundaries" -split-input-file -o /dev/null
+// RUN: mlir-opt %s -allow-unregistered-dialect -one-shot-bufferize="allow-return-allocs fully-dynamic-layout-maps=0 bufferize-function-boundaries" -split-input-file -o /dev/null
 
 // CHECK-DAG: #[[$map_1d_dyn:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
 
@@ -328,3 +328,124 @@
 // CHECK: return %[[r0]], %[[r1]]
   return %f0, %f1: f32, f32
 }
+
+// -----
+
+// CHECK-LABEL: func @scf_while(
+// CHECK-SAME: %[[arg0:.*]]: memref<?xi1, #{{.*}}>
+func.func @scf_while(%arg0: tensor<?xi1>, %idx: index) -> tensor<?xi1> {
+  // CHECK: scf.while : () -> () {
+  %res = scf.while (%arg1 = %arg0) : (tensor<?xi1>) -> tensor<?xi1> {
+    // CHECK: %[[condition:.*]] = memref.load %[[arg0]]
+    // CHECK: scf.condition(%[[condition]])
+    %condition = tensor.extract %arg1[%idx] : tensor<?xi1>
+    scf.condition(%condition) %arg1 : tensor<?xi1>
+  } do {
+  ^bb0(%arg2: tensor<?xi1>):
+    // CHECK: } do {
+    // CHECK: memref.store %{{.*}}, %[[arg0]]
+    // CHECK: scf.yield
+    // CHECK: }
+    %pos = "dummy.some_op"() : () -> (index)
+    %val = "dummy.another_op"() : () -> (i1)
+    %1 = tensor.insert %val into %arg2[%pos] : tensor<?xi1>
+    scf.yield %1 : tensor<?xi1>
+  }
+
+  // CHECK: return
+  return %res : tensor<?xi1>
+}
+
+// -----
+
+// The loop condition yields non-equivalent buffers.
+
+// CHECK-LABEL: func @scf_while_non_equiv_condition(
+// CHECK-SAME: %[[arg0:.*]]: memref<5xi1, #{{.*}}>, %[[arg1:.*]]: memref<5xi1, #{{.*}}>
+func.func @scf_while_non_equiv_condition(%arg0: tensor<5xi1>,
+                                         %arg1: tensor<5xi1>,
+                                         %idx: index)
+  -> (tensor<5xi1>, tensor<5xi1>)
+{
+  // These allocations used to be inside the scf.while loop, but they were
+  // hoisted.
+  // CHECK: %[[a0:.*]] = memref.alloc() {{.*}} : memref<5xi1>
+  // CHECK: %[[a1:.*]] = memref.alloc() {{.*}} : memref<5xi1>
+  // CHECK: %[[loop:.*]]:2 = scf.while (%[[w0:.*]] = %[[arg0]], %[[w1:.*]] = %[[arg1]]) {{.*}} {
+  %r0, %r1 = scf.while (%w0 = %arg0, %w1 = %arg1)
+      : (tensor<5xi1>, tensor<5xi1>) -> (tensor<5xi1>, tensor<5xi1>) {
+    // CHECK: %[[condition:.*]] = memref.load %[[w0]]
+    // CHECK: memref.copy %[[w1]], %[[a1]]
+    // CHECK: %[[casted1:.*]] = memref.cast %[[a1]]
+    // CHECK: memref.copy %[[w0]], %[[a0]]
+    // CHECK: %[[casted0:.*]] = memref.cast %[[a0]]
+    // CHECK: scf.condition(%[[condition]]) %[[casted1]], %[[casted0]]
+    %condition = tensor.extract %w0[%idx] : tensor<5xi1>
+    scf.condition(%condition) %w1, %w0 : tensor<5xi1>, tensor<5xi1>
+  } do {
+  ^bb0(%b0: tensor<5xi1>, %b1: tensor<5xi1>):
+    // CHECK: } do {
+    // CHECK: ^bb0(%[[b0:.*]]: memref<5xi1, #{{.*}}>, %[[b1:.*]]: memref<5xi1, #{{.*}}>):
+    // CHECK: memref.store %{{.*}}, %[[b0]]
+    // CHECK: scf.yield %[[b0]], %[[b1]]
+    // CHECK: }
+    %pos = "dummy.some_op"() : () -> (index)
+    %val = "dummy.another_op"() : () -> (i1)
+    %1 = tensor.insert %val into %b0[%pos] : tensor<5xi1>
+    scf.yield %1, %b1 : tensor<5xi1>, tensor<5xi1>
+  }
+
+  // CHECK: return %[[loop]]#0, %[[loop]]#1
+  return %r0, %r1 : tensor<5xi1>, tensor<5xi1>
+}
+
+// -----
+
+// Both the loop condition and the loop body yield non-equivalent buffers.
+
+// CHECK-LABEL: func @scf_while_non_equiv_condition_and_body(
+// CHECK-SAME: %[[arg0:.*]]: memref<5xi1, #{{.*}}>, %[[arg1:.*]]: memref<5xi1, #{{.*}}>
+func.func @scf_while_non_equiv_condition_and_body(%arg0: tensor<5xi1>,
+                                                  %arg1: tensor<5xi1>,
+                                                  %idx: index)
+  -> (tensor<5xi1>, tensor<5xi1>)
+{
+  // These allocations used to be inside the scf.while loop, but they were
+  // hoisted.
+  // CHECK: %[[a0:.*]] = memref.alloc() {{.*}} : memref<5xi1>
+  // CHECK: %[[a1:.*]] = memref.alloc() {{.*}} : memref<5xi1>
+  // CHECK: %[[a2:.*]] = memref.alloc() {{.*}} : memref<5xi1>
+  // CHECK: %[[a3:.*]] = memref.alloc() {{.*}} : memref<5xi1>
+  // CHECK: %[[loop:.*]]:2 = scf.while (%[[w0:.*]] = %[[arg0]], %[[w1:.*]] = %[[arg1]]) {{.*}} {
+  %r0, %r1 = scf.while (%w0 = %arg0, %w1 = %arg1)
+      : (tensor<5xi1>, tensor<5xi1>) -> (tensor<5xi1>, tensor<5xi1>) {
+    // CHECK: %[[condition:.*]] = memref.load %[[w0]]
+    // CHECK: memref.copy %[[w1]], %[[a3]]
+    // CHECK: %[[casted3:.*]] = memref.cast %[[a3]]
+    // CHECK: memref.copy %[[w0]], %[[a2]]
+    // CHECK: %[[casted2:.*]] = memref.cast %[[a2]]
+    // CHECK: scf.condition(%[[condition]]) %[[casted3]], %[[casted2]]
+    %condition = tensor.extract %w0[%idx] : tensor<5xi1>
+    scf.condition(%condition) %w1, %w0 : tensor<5xi1>, tensor<5xi1>
+  } do {
+  ^bb0(%b0: tensor<5xi1>, %b1: tensor<5xi1>):
+    // CHECK: } do {
+    // CHECK: ^bb0(%[[b0:.*]]: memref<5xi1, #{{.*}}>, %[[b1:.*]]: memref<5xi1, #{{.*}}>):
+    // CHECK: memref.store %{{.*}}, %[[b0]]
+    // CHECK: memref.copy %[[b1]], %[[a1]]
+    // CHECK: %[[casted1:.*]] = memref.cast %[[a1]]
+    // CHECK: memref.copy %[[b0]], %[[a0]]
+    // CHECK: %[[casted0:.*]] = memref.cast %[[a0]]
+    // CHECK: scf.yield %[[casted1]], %[[casted0]]
+    // CHECK: }
+    %pos = "dummy.some_op"() : () -> (index)
+    %val = "dummy.another_op"() : () -> (i1)
+    %1 = tensor.insert %val into %b0[%pos] : tensor<5xi1>
+    scf.yield %b1, %1 : tensor<5xi1>, tensor<5xi1>
+  }
+
+  // CHECK-DAG: memref.dealloc %[[a0]]
+  // CHECK-DAG: memref.dealloc %[[a1]]
+  // CHECK: return %[[loop]]#0, %[[loop]]#1
+  return %r0, %r1 : tensor<5xi1>, tensor<5xi1>
+}
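For reference, the first new test (@scf_while) exercises the equivalent-buffer case: every value forwarded by scf.condition/scf.yield is equivalent to its iter bbArg, so bufferRelation returns BufferRelation::Equivalent and no allocs or copies are inserted. Pieced together from the CHECK lines above, the bufferized IR is expected to look roughly like the following sketch (hand-written from the CHECK lines, not actual pass output; it assumes the usual fully dynamic layout map, here called #map, and exact casts and attributes may differ):

#map = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>

// Sketch: the loop no longer carries any values; all reads and writes go
// directly through the buffer of the function argument.
func.func @scf_while(%arg0: memref<?xi1, #map>, %idx: index) {
  scf.while : () -> () {
    %condition = memref.load %arg0[%idx] : memref<?xi1, #map>
    scf.condition(%condition)
  } do {
    %pos = "dummy.some_op"() : () -> (index)
    %val = "dummy.another_op"() : () -> (i1)
    memref.store %val, %arg0[%pos] : memref<?xi1, #map>
    scf.yield
  }
  return
}

The two non-equivalent tests show the other path: getYieldedValues yields casted copies of freshly hoisted allocations instead, which is only legal under allow-return-allocs and is what verifyAnalysis rejects otherwise.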