diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -64,6 +64,10 @@
 /// outside of test cases.
 TransformState makeTransformStateForTesting(Region *region,
                                             Operation *payloadRoot);
+
+/// Returns all operands that are handles and being consumed by the given op.
+SmallVector<OpOperand *>
+getConsumedHandleOpOperands(transform::TransformOpInterface transformOp);
 } // namespace detail
 
 /// Options controlling the application of transform operations by the
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
@@ -100,6 +100,11 @@
       diag.attachNote(target->getLoc()) << "when applied to this op";
       return diag;
     }
+
+    /// Returns all operands that are handles and being consumed by this op.
+    ::llvm::SmallVector<OpOperand *> getConsumedHandleOpOperands() {
+      return ::mlir::transform::detail::getConsumedHandleOpOperands($_op);
+    }
   }];
 
   let verify = [{
diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -843,21 +843,8 @@
   }
 
   // Find which operands are consumed.
-  DenseSet<unsigned> consumedOperands;
-  auto memEffectInterface =
-      cast<MemoryEffectOpInterface>(transform.getOperation());
-  SmallVector<MemoryEffects::EffectInstance, 2> effects;
-  for (OpOperand &target : transform->getOpOperands()) {
-    effects.clear();
-    memEffectInterface.getEffectsOnValue(target.get(), effects);
-    if (llvm::any_of(effects, [](const MemoryEffects::EffectInstance &effect) {
-          return isa<transform::TransformMappingResource>(
-                     effect.getResource()) &&
-                 isa<MemoryEffects::Free>(effect.getEffect());
-        })) {
-      consumedOperands.insert(target.getOperandNumber());
-    }
-  }
+  SmallVector<OpOperand *> consumedOperands =
+      transform.getConsumedHandleOpOperands();
 
   // Remember the results of the payload ops associated with the consumed
   // op handles or the ops defining the value handles so we can drop the
@@ -869,8 +856,8 @@
 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
   DenseSet<Operation *> consumedPayloadOps;
 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
-  for (unsigned index : consumedOperands) {
-    Value operand = transform->getOperand(index);
+  for (OpOperand *opOperand : consumedOperands) {
+    Value operand = opOperand->get();
     if (llvm::isa<TransformHandleTypeInterface>(operand.getType())) {
       for (Operation *payloadOp : getPayloadOps(operand)) {
         llvm::append_range(origOpFlatResults, payloadOp->getResults());
@@ -901,7 +888,7 @@
       DiagnosedDefiniteFailure diag =
           emitDefiniteFailure(transform->getLoc())
           << "unexpectedly consumed a value that is not a handle as operand #"
-          << index;
+          << opOperand->getOperandNumber();
       diag.attachNote(operand.getLoc())
           << "value defined here with type " << operand.getType();
       return diag;
@@ -923,8 +910,8 @@
 
   // Remove the mapping for the operand if it is consumed by the operation. This
   // allows us to catch use-after-free with assertions later on.
-  for (unsigned index : consumedOperands) {
-    Value operand = transform->getOperand(index);
+  for (OpOperand *opOperand : consumedOperands) {
+    Value operand = opOperand->get();
     if (llvm::isa<TransformHandleTypeInterface>(operand.getType())) {
       forgetMapping(operand, origOpFlatResults);
     } else if (llvm::isa<TransformValueHandleTypeInterface>(
@@ -1593,6 +1580,27 @@
 // Utilities for TransformOpInterface.
 //===----------------------------------------------------------------------===//
 
+SmallVector<OpOperand *> transform::detail::getConsumedHandleOpOperands(
+    TransformOpInterface transformOp) {
+  SmallVector<OpOperand *> consumedOperands;
+  consumedOperands.reserve(transformOp->getNumOperands());
+  auto memEffectInterface =
+      cast<MemoryEffectOpInterface>(transformOp.getOperation());
+  SmallVector<MemoryEffects::EffectInstance, 2> effects;
+  for (OpOperand &target : transformOp->getOpOperands()) {
+    effects.clear();
+    memEffectInterface.getEffectsOnValue(target.get(), effects);
+    if (llvm::any_of(effects, [](const MemoryEffects::EffectInstance &effect) {
+          return isa<transform::TransformMappingResource>(
+                     effect.getResource()) &&
+                 isa<MemoryEffects::Free>(effect.getEffect());
+        })) {
+      consumedOperands.push_back(&target);
+    }
+  }
+  return consumedOperands;
+}
+
 LogicalResult transform::detail::verifyTransformOpInterface(Operation *op) {
   auto iface = cast<MemoryEffectOpInterface>(op);
   SmallVector<MemoryEffects::EffectInstance> effects;