diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -555,7 +555,8 @@
                           ArrayRef<Operation *> payloadOperations);
 
   /// Replaces the given payload op with another op. If the replacement op is
-  /// null, removes the association of the payload op with its handle.
+  /// null, removes the association of the payload op with its handle. Returns
+  /// failure if the op is not associated with any handle.
   ///
   /// Note: This function does not update value handles. None of the original
   /// op's results are allowed to be mapped to any value handle.
@@ -563,7 +564,7 @@
 
   /// Replaces the given payload value with another value. If the replacement
   /// value is null, removes the association of the payload value with its
-  /// handle.
+  /// handle. Returns failure if the value is not associated with any handle.
   LogicalResult replacePayloadValue(Value value, Value replacement);
 
   /// Records handle invalidation reporters into `newlyInvalidated`.
diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -338,15 +338,6 @@
 LogicalResult
 transform::TransformState::replacePayloadOp(Operation *op,
                                             Operation *replacement) {
-  // Drop the mapping between the op and all handles that point to it. Don't
-  // care if there are on such handles.
-  SmallVector<Value> opHandles;
-  (void)getHandlesForPayloadOp(op, opHandles);
-  for (Value handle : opHandles) {
-    Mappings &mappings = getMapping(handle);
-    dropMappingEntry(mappings.reverse, op, handle);
-  }
-
 #ifndef NDEBUG
   for (Value opResult : op->getResults()) {
     SmallVector<Value> valueHandles;
@@ -364,6 +355,16 @@
   }
 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
 
+  // Drop the mapping between the op and all handles that point to it. Fail if
+  // there are no handles.
+  SmallVector<Value> opHandles;
+  if (failed(getHandlesForPayloadOp(op, opHandles)))
+    return failure();
+  for (Value handle : opHandles) {
+    Mappings &mappings = getMapping(handle);
+    dropMappingEntry(mappings.reverse, op, handle);
+  }
+
   // TODO: consider invalidating the handles to nested objects here.
 
 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
@@ -411,7 +412,8 @@
 LogicalResult
 transform::TransformState::replacePayloadValue(Value value, Value replacement) {
   SmallVector<Value> valueHandles;
-  (void)getHandlesForPayloadValue(value, valueHandles);
+  if (failed(getHandlesForPayloadValue(value, valueHandles)))
+    return failure();
 
   for (Value handle : valueHandles) {
     Mappings &mappings = getMapping(handle);
@@ -1064,10 +1066,6 @@
 LogicalResult
 transform::TransformState::Extension::replacePayloadOp(Operation *op,
                                                        Operation *replacement) {
-  SmallVector<Value> handles;
-  if (failed(state.getHandlesForPayloadOp(op, handles)))
-    return failure();
-
   // TODO: we may need to invalidate handles to operations and values nested in
   // the operation being replaced.
   return state.replacePayloadOp(op, replacement);
@@ -1076,10 +1074,6 @@
 LogicalResult
 transform::TransformState::Extension::replacePayloadValue(Value value,
                                                           Value replacement) {
-  SmallVector<Value> handles;
-  if (failed(state.getHandlesForPayloadValue(value, handles)))
-    return failure();
-
   return state.replacePayloadValue(value, replacement);
 }
 