diff --git a/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h --- a/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h +++ b/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h @@ -53,15 +53,15 @@ DiagnosedSilenceableFailure apply(TransformResults &results, TransformState &state) { Value operandHandle = cast(this->getOperation()).getOperandHandle(); - ArrayRef payload = state.getPayloadOps(operandHandle); - if (payload.size() != 1) { + auto payload = state.getPayloadOps(operandHandle); + if (!llvm::hasSingleElement(payload)) { return emitDefiniteFailure(this->getOperation()->getLoc()) << "SingleOpMatchOpTrait requires the operand handle to point to " "a single payload op"; } return cast(this->getOperation()) - .matchOperation(payload[0], results, state); + .matchOperation(*payload.begin(), results, state); } void getEffects(SmallVectorImpl &effects) { diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h --- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h @@ -22,6 +22,36 @@ class TransformOpInterface; class TransformResults; +class TransformState; + +using Param = Attribute; +using MappedValue = llvm::PointerUnion; + +namespace detail { +/// Maps the only block argument of the op with PossibleTopLevelTransformOpTrait +/// to either the list of operations associated with its operand or the root of +/// the payload IR, depending on what is available in the context. +LogicalResult +mapPossibleTopLevelTransformOpBlockArguments(TransformState &state, + Operation *op, Region ®ion); + +/// Verification hook for PossibleTopLevelTransformOpTrait. +LogicalResult verifyPossibleTopLevelTransformOpTrait(Operation *op); + +/// Verification hook for TransformOpInterface. +LogicalResult verifyTransformOpInterface(Operation *op); + +/// Populates `mappings` with mapped values associated with the given transform +/// IR values in the given `state`. +void prepareValueMappings( + SmallVectorImpl> &mappings, + ValueRange values, const transform::TransformState &state); + +/// Populates `results` with payload associations that match exactly those of +/// the operands to `block`'s terminator. +void forwardTerminatorOperands(Block *block, transform::TransformState &state, + transform::TransformResults &results); +} // namespace detail /// Options controlling the application of transform operations by the /// TransformState. @@ -46,9 +76,6 @@ bool expensiveChecksEnabled = true; }; -using Param = Attribute; -using MappedValue = llvm::PointerUnion; - /// Entry point to the Transform dialect infrastructure. Applies the /// transformation specified by `transform` to payload IR contained in /// `payloadRoot`. The `transform` operation may contain other operations that @@ -96,6 +123,11 @@ using Param = transform::Param; private: + friend void + detail::forwardTerminatorOperands(Block *block, + transform::TransformState &state, + transform::TransformResults &results); + /// Mapping between a Value in the transform IR and the corresponding set of /// operations in the payload IR. using TransformOpMapping = DenseMap>; @@ -140,9 +172,16 @@ return topLevelMappedValues[position]; } - /// Returns the list of ops that the given transform IR value corresponds to. - /// This is helpful for transformations that apply to a particular handle. - ArrayRef getPayloadOps(Value value) const; + /// Returns an iterator that enumerates all ops that the given transform IR + /// value corresponds to at the time when this function is called. Ops may be + /// erased while iterating; erased ops are not enumerated. This function is + /// helpful for transformations that apply to a particular handle. + auto getPayloadOps(Value value) const { + // When ops are replaced/erased, they are replaced with nullptr (until + // the data structure is compacted). Do not enumerate these ops. + return llvm::make_filter_range(getPayloadOpsView(value), + [](Operation *op) { return op != nullptr; }); + } /// Returns the list of parameters that the given transform IR value /// corresponds to. @@ -407,6 +446,12 @@ LogicalResult updateStateFromResults(const TransformResults &results, ResultRange opResults); + /// Returns a list of all ops that the given transform IR value corresponds to + /// at the time when this function is called. In case an op was erased, the + /// returned list contains nullptr. This function is helpful for + /// transformations that apply to a particular handle. + ArrayRef getPayloadOpsView(Value value) const; + /// Sets the payload IR ops associated with the given transform IR value /// (handle). A payload op may be associated multiple handles as long as /// at most one of them gets consumed by further transformations. @@ -544,6 +589,8 @@ /// the region in which the transform IR values are defined. llvm::SmallDenseMap mappings; + SmallVector opHandlesToCompact; + /// Extensions attached to the TransformState, identified by the TypeID of /// their type. Only one extension of any given type is allowed. DenseMap> extensions; @@ -595,7 +642,25 @@ /// corresponds to the given list of payload IR ops. Each result must be set /// by the transformation exactly once in case of transformation succeeding. /// The value must have a type implementing TransformHandleTypeInterface. - void set(OpResult value, ArrayRef ops); + template void set(OpResult value, Range ops) { + int64_t position = value.getResultNumber(); + assert(position < static_cast(operations.size()) && + "setting results for a non-existent handle"); + assert(operations[position].data() == nullptr && "results already set"); + assert(params[position].data() == nullptr && + "another kind of results already set"); + assert(values[position].data() == nullptr && + "another kind of results already set"); + operations.replace(position, ops); + } + + /// Indicates that the result of the transform IR op at the given position + /// corresponds to the given payload IR op. Each result must be set by the + /// transformation exactly once in case of transformation succeeding. The + /// value must have a type implementing TransformHandleTypeInterface. + void set(OpResult value, Operation *op) { + set(value, ArrayRef(op)); + } /// Indicates that the result of the transform IR op at the given position /// corresponds to the given list of parameters. Each result must be set by @@ -682,32 +747,6 @@ return RegionScope(*this, region, RegionScope::Isolated()); } -namespace detail { -/// Maps the only block argument of the op with PossibleTopLevelTransformOpTrait -/// to either the list of operations associated with its operand or the root of -/// the payload IR, depending on what is available in the context. -LogicalResult -mapPossibleTopLevelTransformOpBlockArguments(TransformState &state, - Operation *op, Region ®ion); - -/// Verification hook for PossibleTopLevelTransformOpTrait. -LogicalResult verifyPossibleTopLevelTransformOpTrait(Operation *op); - -/// Verification hook for TransformOpInterface. -LogicalResult verifyTransformOpInterface(Operation *op); - -/// Populates `mappings` with mapped values associated with the given transform -/// IR values in the given `state`. -void prepareValueMappings( - SmallVectorImpl> &mappings, - ValueRange values, const transform::TransformState &state); - -/// Populates `results` with payload associations that match exactly those of -/// the operands to `block`'s terminator. -void forwardTerminatorOperands(Block *block, transform::TransformState &state, - transform::TransformResults &results); -} // namespace detail - /// This trait is supposed to be attached to Transform dialect operations that /// can be standalone top-level transforms. Such operations typically contain /// other Transform dialect operations that can be executed following some @@ -1069,9 +1108,9 @@ /// - a concrete Op class, in which case a check is performed whether /// `targets` contains operations of the same class and a silenceable failure /// is reported if it does not. -template +template DiagnosedSilenceableFailure -applyTransformToEach(TransformOpTy transformOp, ArrayRef targets, +applyTransformToEach(TransformOpTy transformOp, Range targets, SmallVectorImpl &results, TransformState &state) { using OpTy = typename llvm::function_traits< @@ -1133,14 +1172,13 @@ mlir::DiagnosedSilenceableFailure mlir::transform::TransformEachOpTrait::apply( TransformResults &transformResults, TransformState &state) { - ArrayRef targets = - state.getPayloadOps(this->getOperation()->getOperand(0)); + auto targets = state.getPayloadOps(this->getOperation()->getOperand(0)); // Step 1. Handle the corner case where no target is specified. // This is typically the case when the matcher fails to apply and we need to // propagate gracefully. // In this case, we fill all results with an empty vector. - if (targets.empty()) { + if (std::begin(targets) == std::end(targets)) { SmallVector emptyPayload; SmallVector emptyParams; for (OpResult r : this->getOperation()->getResults()) { @@ -1157,7 +1195,6 @@ // Step 2. Call applyToOne on each target and record newly produced ops in its // corresponding results entry. SmallVector results; - results.reserve(targets.size()); DiagnosedSilenceableFailure result = detail::applyTransformToEach( cast(this->getOperation()), targets, results, state); diff --git a/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp b/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp --- a/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp +++ b/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp @@ -75,7 +75,7 @@ for (const auto &it : llvm::zip_equal(getBoundedValues(), getLowerBounds(), getUpperBounds())) { Value handle = std::get<0>(it); - ArrayRef boundedValueOps = state.getPayloadOps(handle); + auto boundedValueOps = state.getPayloadOps(handle); for (Operation *op : boundedValueOps) { if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) { auto diag = @@ -104,8 +104,8 @@ } // Transform all targets. - ArrayRef targets = state.getPayloadOps(getTarget()); - for (Operation *target : targets) { + SmallVector targets; + for (Operation *target : state.getPayloadOps(getTarget())) { if (!isa(target)) { auto diag = emitDefiniteFailure() << "target must be affine.min or affine.max"; @@ -118,6 +118,7 @@ diag.attachNote(target->getLoc()) << "target/constrained op"; return diag; } + targets.push_back(target); } SmallVector transformed; RewritePatternSet patterns(getContext()); diff --git a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp --- a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp +++ b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp @@ -41,7 +41,7 @@ options.setFunctionBoundaryTypeConversion( *getFunctionBoundaryTypeConversion()); - ArrayRef payloadOps = state.getPayloadOps(getTarget()); + auto payloadOps = state.getPayloadOps(getTarget()); for (Operation *target : payloadOps) { if (!isa(target)) return emitSilenceableError() << "expected module or function target"; @@ -80,7 +80,7 @@ OneShotBufferizationOptions options; options.allowReturnAllocs = true; - ArrayRef payloadOps = state.getPayloadOps(getTarget()); + auto payloadOps = state.getPayloadOps(getTarget()); for (Operation *target : payloadOps) { OneShotAnalysisState state(target, options); if (failed(analyzeOp(target, state))) diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -92,17 +92,17 @@ result.push_back(ofr); continue; } - ArrayRef payloadOps = state.getPayloadOps(ofr.get()); - if (payloadOps.size() != 1) { + auto payloadOps = state.getPayloadOps(ofr.get()); + if (!llvm::hasSingleElement(payloadOps)) { DiagnosedSilenceableFailure diag = transformOp.emitSilenceableError() << "handle must be mapped to exactly one payload op"; diag.attachNote(ofr.get().getLoc()) - << "mapped to " << payloadOps.size() << " payload ops"; + << "mapped to " << llvm::range_size(payloadOps) << " payload ops"; return diag; } - Operation *op = payloadOps[0]; + Operation *op = *payloadOps.begin(); if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) { DiagnosedSilenceableFailure diag = transformOp.emitSilenceableError() @@ -125,7 +125,7 @@ static DiagnosedSilenceableFailure unpackSingleIndexResultPDLOperations( transform::TransformState &state, TransformOpInterface transformOp, SmallVector &result, Value packedHandle) { - ArrayRef payloadOps = state.getPayloadOps(packedHandle); + auto payloadOps = state.getPayloadOps(packedHandle); for (Operation *op : payloadOps) { if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) { DiagnosedSilenceableFailure diag = @@ -208,16 +208,14 @@ /// Apply a tiling transformation to all payload ops and store both the /// tiled operation as well as the created tile loops. +template static LogicalResult applyTilingToAll( - RewriterBase &rewriter, Operation *transformOp, - ArrayRef payloadOps, unsigned numLoops, - transform::TransformResults &transformResults, + RewriterBase &rewriter, Operation *transformOp, Range payloadOps, + unsigned numLoops, transform::TransformResults &transformResults, function_ref(TilingInterface)> applyFn) { SmallVector tiledLinalgOps; SmallVector> loopOps(numLoops); - for (unsigned int i = 0; i < numLoops; ++i) - loopOps[i].reserve(payloadOps.size()); for (Operation *target : payloadOps) { auto tilingInterfaceOp = dyn_cast(target); @@ -584,20 +582,20 @@ transform::FuseIntoContainingOp::apply(transform::TransformResults &results, transform::TransformState &state) { SmallVector fusedOps; - ArrayRef producerOps = state.getPayloadOps(getProducerOp()); + auto producerOps = state.getPayloadOps(getProducerOp()); // If nothing to fuse, propagate success. - if (producerOps.empty()) { + if (producerOps.begin() == producerOps.end()) { results.set(getFusedOp().cast(), SmallVector{}); return DiagnosedSilenceableFailure::success(); } - ArrayRef containingOps = state.getPayloadOps(getContainingOp()); - if (containingOps.size() != 1) { + auto containingOps = state.getPayloadOps(getContainingOp()); + if (!llvm::hasSingleElement(containingOps)) { return emitDefiniteFailure() << "requires exactly one containing_op handle (got " - << containingOps.size() << ")"; + << llvm::range_size(containingOps) << ")"; } - Operation *containingOp = containingOps.front(); + Operation *containingOp = *containingOps.begin(); // Helper function to find the next producer that should be fused. Take any // producer that has a use inside the containing op. @@ -817,8 +815,8 @@ strs.insert(getOps()->getAsValueRange().begin(), getOps()->getAsValueRange().end()); - ArrayRef payloadOps = state.getPayloadOps(getTarget()); - if (payloadOps.size() != 1) { + auto payloadOps = state.getPayloadOps(getTarget()); + if (!llvm::hasSingleElement(payloadOps)) { return emitDefiniteFailure("requires exactly one target handle"); } @@ -864,7 +862,7 @@ return; }; - payloadOps.front()->walk(matchFun); + (*payloadOps.begin())->walk(matchFun); results.set(getResult().cast(), res); return DiagnosedSilenceableFailure::success(); } @@ -1003,18 +1001,19 @@ DiagnosedSilenceableFailure transform::PackOp::apply(transform::TransformResults &transformResults, transform::TransformState &state) { - ArrayRef targetOps = state.getPayloadOps(getTarget()); + auto targetOps = state.getPayloadOps(getTarget()); // If nothing to pack, propagate success. - if (targetOps.empty()) { - transformResults.set(getPackedOp().cast(), {}); + if (targetOps.begin() == targetOps.end()) { + transformResults.set(getPackedOp().cast(), + ArrayRef({})); return DiagnosedSilenceableFailure::success(); } // Fail on multi-op handles. - auto linalgOp = dyn_cast(targetOps.front()); - if (targetOps.size() != 1 || !linalgOp) { + auto linalgOp = dyn_cast(*targetOps.begin()); + if (!llvm::hasSingleElement(targetOps) || !linalgOp) { return emitSilenceableError() << "requires target to map to exactly 1 LinalgOp (got " - << targetOps.size() << ")"; + << llvm::range_size(targetOps) << ")"; } // Fail on mismatched number of pack sizes. if (getMixedPackedSizes().size() != linalgOp.getNumLoops()) { @@ -1211,7 +1210,7 @@ DiagnosedSilenceableFailure PackGreedilyOp::apply(transform::TransformResults &transformResults, transform::TransformState &state) { - ArrayRef targetOpsView = state.getPayloadOps(getTarget()); + auto targetOpsView = state.getPayloadOps(getTarget()); // Store payload ops into a separate SmallVector because the TrackingListener // removes erased ops from the transform state. SmallVector targetOps(targetOpsView.begin(), @@ -1317,34 +1316,38 @@ DiagnosedSilenceableFailure transform::PackTransposeOp::apply(transform::TransformResults &transformResults, transform::TransformState &state) { - ArrayRef packOrUnpackOps = - state.getPayloadOps(getTargetPackOrUnPackOp()); - ArrayRef linalgOps = state.getPayloadOps(getTargetLinalgOp()); + auto packOrUnpackOps = state.getPayloadOps(getTargetPackOrUnPackOp()); + auto linalgOps = state.getPayloadOps(getTargetLinalgOp()); // Step 1. If nothing to pack, propagate success. - if (packOrUnpackOps.empty()) { - transformResults.set(getPackedOp().cast(), {}); - transformResults.set(getPackOp().cast(), {}); - transformResults.set(getUnPackOp().cast(), {}); + if (packOrUnpackOps.begin() == packOrUnpackOps.end()) { + transformResults.set(getPackedOp().cast(), + ArrayRef({})); + transformResults.set(getPackOp().cast(), + ArrayRef({})); + transformResults.set(getUnPackOp().cast(), + ArrayRef({})); return DiagnosedSilenceableFailure::success(); } // Step 2. Bunch of runtime sanity check and error messages. // Step 2.1. Fail on multi-op handles. - if (packOrUnpackOps.size() != 1 || linalgOps.size() != 1) { - return emitSilenceableError() << "requires target to map to exactly 1 " - "packing op and 1 packed op (" - << "got " << packOrUnpackOps.size() << " and " - << linalgOps.size() << ")"; + if (!llvm::hasSingleElement(packOrUnpackOps) || + !llvm::hasSingleElement(linalgOps)) { + return emitSilenceableError() + << "requires target to map to exactly 1 " + "packing op and 1 packed op (" + << "got " << llvm::range_size(packOrUnpackOps) << " and " + << llvm::range_size(linalgOps) << ")"; } // Step 2.2. Fail on wrong type. - auto packOp = dyn_cast(packOrUnpackOps.front()); - auto unPackOp = dyn_cast(packOrUnpackOps.front()); + auto packOp = dyn_cast(*packOrUnpackOps.begin()); + auto unPackOp = dyn_cast(*packOrUnpackOps.begin()); if ((!packOp && !unPackOp)) { return emitSilenceableError() << "requires target to map to a " "tensor.pack or tensor.unpack"; } - LinalgOp linalgOpTarget = dyn_cast(linalgOps.front()); + LinalgOp linalgOpTarget = dyn_cast(*linalgOps.begin()); if (!linalgOpTarget) return emitSilenceableError() << "requires a LinalgOp target"; @@ -1400,14 +1403,16 @@ assert(succeeded(res) && "unexpected packTranspose failure"); // Step 4. Return results. - transformResults.set(getPackOp().cast(), {res->transposedPackOp}); + transformResults.set(getPackOp().cast(), + ArrayRef({res->transposedPackOp})); transformResults.set(getPackedOp().cast(), - {res->transposedLinalgOp}); + ArrayRef({res->transposedLinalgOp})); if (unPackOp) { transformResults.set(getUnPackOp().cast(), - {res->transposedUnPackOp}); + ArrayRef({res->transposedUnPackOp})); } else { - transformResults.set(getUnPackOp().cast(), {}); + transformResults.set(getUnPackOp().cast(), + ArrayRef({})); } return DiagnosedSilenceableFailure::success(); @@ -1527,16 +1532,17 @@ DiagnosedSilenceableFailure transform::HoistPadBuildPackingLoopNestOp::apply( transform::TransformResults &transformResults, transform::TransformState &state) { - ArrayRef targetOps = state.getPayloadOps(getTarget()); - ArrayRef loopOps = state.getPayloadOps(getLoop()); - if (targetOps.size() != 1 || loopOps.size() != 1) { + auto targetOps = state.getPayloadOps(getTarget()); + auto loopOps = state.getPayloadOps(getLoop()); + if (!llvm::hasSingleElement(targetOps) || !llvm::hasSingleElement(loopOps)) { return emitDefiniteFailure() << "requires exactly one target and one loop handle (got " - << targetOps.size() << " and " << loopOps.size() << ")"; + << llvm::range_size(targetOps) << " and " + << llvm::range_size(loopOps) << ")"; } - auto padOp = dyn_cast_or_null(targetOps.front()); - auto loopOp = dyn_cast_or_null(loopOps.front()); + auto padOp = dyn_cast_or_null(*targetOps.begin()); + auto loopOp = dyn_cast_or_null(*loopOps.begin()); if (!padOp || !loopOp) return emitDefiniteFailure() << "requires exactly 2 non-null handles"; @@ -1686,7 +1692,7 @@ DiagnosedSilenceableFailure transform::ReplaceOp::apply(TransformResults &transformResults, TransformState &state) { - ArrayRef payload = state.getPayloadOps(getTarget()); + auto payload = state.getPayloadOps(getTarget()); // Check for invalid targets. for (Operation *target : payload) { @@ -1821,7 +1827,8 @@ DiagnosedSilenceableFailure SplitOp::apply(TransformResults &results, TransformState &state) { // Collect the dynamic split points if provided. - ArrayRef payload = state.getPayloadOps(getTarget()); + SmallVector payload = + llvm::to_vector(state.getPayloadOps(getTarget())); TrackingListener listener(state, *this); IRRewriter rewriter(getContext(), &listener); SmallVector splitPoints; @@ -2206,8 +2213,9 @@ TransformState &state) { ArrayRef tileSizes = getStaticSizes(); - ArrayRef targets = state.getPayloadOps(getTarget()); - SmallVector> dynamicSizeProducers; + SmallVector targets = + llvm::to_vector(state.getPayloadOps(getTarget())); + SmallVector> dynamicSizeProducers; SmallVector> paramSizes; dynamicSizeProducers.reserve(getDynamicSizes().size()); paramSizes.reserve(getDynamicSizes().size()); @@ -2233,7 +2241,8 @@ continue; } paramSizes.push_back({}); - dynamicSizeProducers.push_back(state.getPayloadOps(transformValue)); + dynamicSizeProducers.push_back( + llvm::to_vector(state.getPayloadOps(transformValue))); if (dynamicSizeProducers.back().size() != targets.size()) { DiagnosedSilenceableFailure diag = @@ -2549,7 +2558,7 @@ TrackingListener listener(state, *this); IRRewriter rewriter(getContext(), &listener); auto transformOp = cast(getOperation()); - ArrayRef targetsView = state.getPayloadOps(getTarget()); + auto targetsView = state.getPayloadOps(getTarget()); // Store payload ops into a separate SmallVector because the TrackingListener // removes erased ops from the transform state. SmallVector targets(targetsView.begin(), targetsView.end()); @@ -2577,6 +2586,7 @@ if (!status.succeeded()) return status; + // TODO: Put loop into this function. DiagnosedSilenceableFailure diag = tileToForallOpImpl(rewriter, state, transformOp, targets, mixedNumThreads, mixedTileSizes, getMapping(), tileOps, tiledOps); @@ -2661,12 +2671,13 @@ TransformState &state) { ArrayRef tileSizes = getStaticSizes(); - ArrayRef targets = state.getPayloadOps(getTarget()); - SmallVector> dynamicSizeProducers; + SmallVector targets = + llvm::to_vector(state.getPayloadOps(getTarget())); + SmallVector> dynamicSizeProducers; dynamicSizeProducers.reserve(getDynamicSizes().size()); for (Value dynamicSizeProducerHandle : getDynamicSizes()) { dynamicSizeProducers.push_back( - state.getPayloadOps(dynamicSizeProducerHandle)); + llvm::to_vector(state.getPayloadOps(dynamicSizeProducerHandle))); if (dynamicSizeProducers.back().size() != targets.size()) { DiagnosedSilenceableFailure diag = @@ -2893,8 +2904,8 @@ mlir::transform::TransformState &state) { TrackingListener listener(state, *this); IRRewriter rewriter(getContext(), &listener); - ArrayRef targets = state.getPayloadOps(getTarget()); - if (targets.empty()) + auto targets = state.getPayloadOps(getTarget()); + if (targets.begin() == targets.end()) return DiagnosedSilenceableFailure::success(); SmallVector vectorSizes; @@ -2905,16 +2916,16 @@ continue; } - ArrayRef szPayloads = state.getPayloadOps(sz.get()); - if (szPayloads.size() != 1) { + auto szPayloads = state.getPayloadOps(sz.get()); + if (!llvm::hasSingleElement(szPayloads)) { auto diag = this->emitOpError( "requires vector size handle that is mapped to 1 payload op"); diag.attachNote(sz.get().getLoc()) - << "mapped to " << szPayloads.size() << " payload ops"; + << "mapped to " << llvm::range_size(szPayloads) << " payload ops"; return DiagnosedSilenceableFailure::definiteFailure(); } - Operation *szPayloadOp = szPayloads[0]; + Operation *szPayloadOp = *szPayloads.begin(); if (szPayloadOp->getNumResults() != 1 || !szPayloadOp->getResult(0).getType().isIndex()) { auto diag = this->emitOpError( diff --git a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp --- a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp +++ b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp @@ -34,7 +34,7 @@ transform::TransformResults &transformResults, transform::TransformState &state) { SmallVector results; - ArrayRef payloadOps = state.getPayloadOps(getTarget()); + auto payloadOps = state.getPayloadOps(getTarget()); IRRewriter rewriter(getContext()); for (auto *op : payloadOps) { bool canApplyMultiBuffer = true; diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp @@ -53,7 +53,7 @@ Operation *transform::TransformState::getTopLevel() const { return topLevel; } ArrayRef -transform::TransformState::getPayloadOps(Value value) const { +transform::TransformState::getPayloadOpsView(Value value) const { const TransformOpMapping &operationMapping = getMapping(value).direct; auto iter = operationMapping.find(value); assert( @@ -356,18 +356,8 @@ // TODO: consider invalidating the handles to nested objects here. - // If replacing with null, that is erasing the mapping, drop the mapping - // between the handles and the IR objects and return. - if (!replacement) { - for (Value handle : opHandles) { - Mappings &mappings = getMapping(handle); - dropMappingEntry(mappings.direct, handle, op); - } - return success(); - } - #if LLVM_ENABLE_ABI_BREAKING_CHECKS - if (options.getExpensiveChecksEnabled()) { + if (replacement && options.getExpensiveChecksEnabled()) { auto insertion = cachedNames.insert({replacement, replacement->getName()}); if (!insertion.second) { assert(insertion.first->second == replacement->getName() && @@ -376,8 +366,16 @@ } #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS - // Otherwise, replace the pointed-to object of all handles while preserving - // their relative order. First, replace the mapped operation if present. + // Replace the pointed-to object of all handles with nullptr. This ensures + // payload op iterators stay in a consistent state. (Removing an op from the + // mapping would be problematic because removing an element from an array + // invalidates iterators; merely changing the value of elements does not.) + // nullptrs are not enumerated by `getPayloadOps`. They are removed at the end + // of each transform op application. + // + // Replacement ops are appended to the end of the list. Note that, therefore, + // the relative order of ops is not preserved. I.e., the replacement op is + // always at the end of the list of payload ops. for (Value handle : opHandles) { Mappings &mappings = getMapping(handle); auto it = mappings.direct.find(handle); @@ -386,11 +384,26 @@ SmallVector &association = it->getSecond(); // Note that an operation may be associated with the handle more than once. + int64_t counter = 0; for (Operation *&mapped : association) { - if (mapped == op) - mapped = replacement; + if (mapped == op) { + mapped = nullptr; + ++counter; + } + } + assert(counter > 0 && "inconsistent mapping state"); + + if (replacement) { + // Add replacement op to the end of the list. + for (int64_t i = 0; i < counter; ++i) + association.push_back(replacement); + mappings.reverse[replacement].push_back(handle); } - mappings.reverse[replacement].push_back(handle); + + // nullptr was added to the mapping data structure. For efficiency reasons, + // compact the list of payload ops at the end of the transformation by + // removing all nullptrs. + opHandlesToCompact.push_back(handle); } return success(); @@ -644,7 +657,7 @@ FULL_LDBG("----found consume effect -> SKIP\n"); if (target.get().getType().isa()) { FULL_LDBG("----recordOpHandleInvalidation\n"); - ArrayRef payloadOps = getPayloadOps(target.get()); + ArrayRef payloadOps = getPayloadOpsView(target.get()); recordOpHandleInvalidation(target, payloadOps); } else if (target.get() .getType() @@ -696,6 +709,14 @@ DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, DBGS() << "Top-level payload before application:\n" << *getTopLevel() << "\n"); + auto compactOpHandles = llvm::make_scope_exit([this] { + // Remove all nullptr ops from handles that had replacements/erasures. + for (Value handle : opHandlesToCompact) { + Mappings &mappings = getMapping(handle); + dropMappingEntry(mappings.direct, handle, nullptr); + } + opHandlesToCompact.clear(); + }); auto printOnFailureRAII = llvm::make_scope_exit([this] { (void)this; LLVM_DEBUG(DBGS() << "Failing Top-level payload:\n"; getTopLevel()->print( @@ -721,7 +742,7 @@ FULL_LDBG("--checkRepeatedConsumptionInOperand for Operation*\n"); DiagnosedSilenceableFailure check = checkRepeatedConsumptionInOperand( - getPayloadOps(operand.get()), transform, + getPayloadOpsView(operand.get()), transform, operand.getOperandNumber()); if (!check.succeeded()) { FULL_LDBG("----FAILED\n"); @@ -987,19 +1008,6 @@ values.appendEmptyRows(numSegments); } -void transform::TransformResults::set(OpResult value, - ArrayRef ops) { - int64_t position = value.getResultNumber(); - assert(position < static_cast(operations.size()) && - "setting results for a non-existent handle"); - assert(operations[position].data() == nullptr && "results already set"); - assert(params[position].data() == nullptr && - "another kind of results already set"); - assert(values[position].data() == nullptr && - "another kind of results already set"); - operations.replace(position, ops); -} - void transform::TransformResults::setParams( OpResult value, ArrayRef params) { int64_t position = value.getResultNumber(); @@ -1221,7 +1229,7 @@ llvm::zip(block->getTerminator()->getOperands(), block->getParentOp()->getOpResults())) { if (result.getType().isa()) { - results.set(result, state.getPayloadOps(terminatorOperand)); + results.set(result, state.getPayloadOpsView(terminatorOperand)); } else if (result.getType() .isa()) { results.setValues(result, state.getPayloadValues(terminatorOperand)); diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -310,7 +310,7 @@ static void forwardEmptyOperands(Block *block, transform::TransformState &state, transform::TransformResults &results) { for (const auto &res : block->getParentOp()->getOpResults()) - results.set(res, {}); + results.set(res, ArrayRef({})); } DiagnosedSilenceableFailure @@ -785,7 +785,7 @@ DiagnosedSilenceableFailure transform::ForeachOp::apply(transform::TransformResults &results, transform::TransformState &state) { - ArrayRef payloadOps = state.getPayloadOps(getTarget()); + auto payloadOps = state.getPayloadOps(getTarget()); SmallVector> resultOps(getNumResults(), {}); for (Operation *op : payloadOps) { @@ -803,8 +803,7 @@ // Append yielded payload ops to result list (if any). for (unsigned i = 0; i < getNumResults(); ++i) { - ArrayRef yieldedOps = - state.getPayloadOps(getYieldOp().getOperand(i)); + auto yieldedOps = state.getPayloadOps(getYieldOp().getOperand(i)); resultOps[i].append(yieldedOps.begin(), yieldedOps.end()); } } @@ -900,16 +899,16 @@ transform::GetConsumersOfResult::apply(transform::TransformResults &results, transform::TransformState &state) { int64_t resultNumber = getResultNumber(); - ArrayRef payloadOps = state.getPayloadOps(getTarget()); - if (payloadOps.empty()) { - results.set(getResult().cast(), {}); + auto payloadOps = state.getPayloadOps(getTarget()); + if (payloadOps.begin() == payloadOps.end()) { + results.set(getResult().cast(), ArrayRef({})); return DiagnosedSilenceableFailure::success(); } - if (payloadOps.size() != 1) + if (!llvm::hasSingleElement(payloadOps)) return emitDefiniteFailure() << "handle must be mapped to exactly one payload op"; - Operation *target = payloadOps.front(); + Operation *target = *payloadOps.begin(); if (target->getNumResults() <= resultNumber) return emitDefiniteFailure() << "result number overflow"; results.set(getResult().cast(), @@ -1504,8 +1503,9 @@ DiagnosedSilenceableFailure transform::SplitHandlesOp::apply(transform::TransformResults &results, transform::TransformState &state) { - int64_t numResultHandles = - getHandle() ? state.getPayloadOps(getHandle()).size() : 0; + SmallVector payloadOps = + llvm::to_vector(state.getPayloadOps(getHandle())); + int64_t numResultHandles = getHandle() ? payloadOps.size() : 0; int64_t expectedNumResultHandles = getNumResultHandles(); if (numResultHandles != expectedNumResultHandles) { // Empty input handle corner case: always propagates empty handles in both @@ -1520,7 +1520,7 @@ << " handles"; } // Normal successful case. - for (const auto &en : llvm::enumerate(state.getPayloadOps(getHandle()))) + for (const auto &en : llvm::enumerate(payloadOps)) results.set(getResults()[en.index()].cast(), en.value()); return DiagnosedSilenceableFailure::success(); } @@ -1569,11 +1569,12 @@ DiagnosedSilenceableFailure transform::ReplicateOp::apply(transform::TransformResults &results, transform::TransformState &state) { - unsigned numRepetitions = state.getPayloadOps(getPattern()).size(); + unsigned numRepetitions = llvm::range_size(state.getPayloadOps(getPattern())); for (const auto &en : llvm::enumerate(getHandles())) { Value handle = en.value(); if (handle.getType().isa()) { - ArrayRef current = state.getPayloadOps(handle); + SmallVector current = + llvm::to_vector(state.getPayloadOps(handle)); SmallVector payload; payload.reserve(numRepetitions * current.size()); for (unsigned i = 0; i < numRepetitions; ++i) @@ -2006,7 +2007,7 @@ } llvm::outs() << "]]]\n"; - ArrayRef targets = state.getPayloadOps(getTarget()); + auto targets = state.getPayloadOps(getTarget()); for (Operation *target : targets) llvm::outs() << *target << "\n"; diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir --- a/mlir/test/Dialect/Transform/test-interpreter.mlir +++ b/mlir/test/Dialect/Transform/test-interpreter.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s --test-transform-dialect-interpreter -allow-unregistered-dialect --split-input-file --verify-diagnostics +// RUN: mlir-opt %s --test-transform-dialect-interpreter -allow-unregistered-dialect --split-input-file --verify-diagnostics | FileCheck %s transform.sequence failures(propagate) { ^bb0(%arg0: !transform.any_op): @@ -1549,3 +1549,23 @@ return } } + +// ----- + +// CHECK-LABEL: func @test_tracked_rewrite() { +// CHECK-NEXT: "test.foo"() {original_op = "test.bar"} +// CHECK-NEXT: "test.foo"() {original_op = "test.bar"} +// CHECK-NEXT: "test.foo"() {original_op = "test.bar"} +// CHECK-NEXT: } +func.func @test_tracked_rewrite() { + %0 = "test.bar"() : () -> (i1) + %1 = "test.bar"() : () -> (i1) + %2 = "test.bar"() : () -> (i1) +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !transform.any_op): + %bar = transform.structured.match ops{["test.bar"]} in %arg1 : (!transform.any_op) -> !transform.any_op + // expected-remark @below {{1 iterations}} + transform.test_tracked_rewrite %bar : (!transform.any_op) -> () +} diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp @@ -16,6 +16,7 @@ #include "mlir/Dialect/PDL/IR/PDL.h" #include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" +#include "mlir/Dialect/Transform/IR/TransformOps.h" #include "mlir/IR/OpImplementation.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/TypeSwitch.h" @@ -109,10 +110,12 @@ mlir::test::TestProduceSelfHandleOrForwardOperandOp::apply( transform::TransformResults &results, transform::TransformState &state) { if (getOperation()->getNumOperands() != 0) { - results.set(getResult().cast(), - getOperation()->getOperand(0).getDefiningOp()); + results.set( + getResult().cast(), + ArrayRef(getOperation()->getOperand(0).getDefiningOp())); } else { - results.set(getResult().cast(), getOperation()); + results.set(getResult().cast(), + ArrayRef(getOperation())); } return DiagnosedSilenceableFailure::success(); } @@ -191,12 +194,13 @@ DiagnosedSilenceableFailure mlir::test::TestConsumeOperandOfOpKindOrFail::apply( transform::TransformResults &results, transform::TransformState &state) { - ArrayRef payload = state.getPayloadOps(getOperand()); - assert(payload.size() == 1 && "expected a single target op"); - if (payload[0]->getName().getStringRef() != getOpKind()) { + auto payload = state.getPayloadOps(getOperand()); + assert(llvm::hasSingleElement(payload) && "expected a single target op"); + if ((*payload.begin())->getName().getStringRef() != getOpKind()) { return emitSilenceableError() << "op expected the operand to be associated a payload op of kind " - << getOpKind() << " got " << payload[0]->getName().getStringRef(); + << getOpKind() << " got " + << (*payload.begin())->getName().getStringRef(); } emitRemark() << "succeeded"; @@ -230,7 +234,7 @@ DiagnosedSilenceableFailure mlir::test::TestPrintRemarkAtOperandOp::apply( transform::TransformResults &results, transform::TransformState &state) { - ArrayRef payload = state.getPayloadOps(getOperand()); + auto payload = state.getPayloadOps(getOperand()); for (Operation *op : payload) op->emitRemark() << getMessage(); @@ -313,11 +317,12 @@ if (!extension) return emitDefiniteFailure("TestTransformStateExtension missing"); - if (failed(extension->updateMapping(state.getPayloadOps(getOperand()).front(), - getOperation()))) + if (failed(extension->updateMapping( + *state.getPayloadOps(getOperand()).begin(), getOperation()))) return DiagnosedSilenceableFailure::definiteFailure(); if (getNumResults() > 0) - results.set(getResult(0).cast(), getOperation()); + results.set(getResult(0).cast(), + ArrayRef(getOperation())); return DiagnosedSilenceableFailure::success(); } @@ -337,7 +342,7 @@ DiagnosedSilenceableFailure mlir::test::TestReversePayloadOpsOp::apply(transform::TransformResults &results, transform::TransformState &state) { - ArrayRef payloadOps = state.getPayloadOps(getTarget()); + auto payloadOps = state.getPayloadOps(getTarget()); auto reversedOps = llvm::to_vector(llvm::reverse(payloadOps)); results.set(getResult().cast(), reversedOps); return DiagnosedSilenceableFailure::success(); @@ -431,7 +436,7 @@ transform::TransformResults &results, transform::TransformState &state) { if (!getHandle()) emitRemark() << 0; - emitRemark() << state.getPayloadOps(getHandle()).size(); + emitRemark() << llvm::range_size(state.getPayloadOps(getHandle())); return DiagnosedSilenceableFailure::success(); } @@ -598,11 +603,11 @@ } else if (getFirstResultIsNull()) { results.push_back(nullptr); } else { - results.push_back(state.getPayloadOps(getIn()).front()); + results.push_back(*state.getPayloadOps(getIn()).begin()); } if (getSecondResultIsHandle()) { - results.push_back(state.getPayloadOps(getIn()).front()); + results.push_back(*state.getPayloadOps(getIn()).begin()); } else { results.push_back(builder.getI64IntegerAttr(42)); } @@ -666,6 +671,42 @@ return DiagnosedSilenceableFailure::success(); } +void mlir::test::TestTrackedRewriteOp::getEffects( + SmallVectorImpl &effects) { + transform::onlyReadsHandle(getIn(), effects); + transform::modifiesPayload(effects); +} + +DiagnosedSilenceableFailure +mlir::test::TestTrackedRewriteOp::apply(transform::TransformResults &results, + transform::TransformState &state) { + transform::TrackingListener listener(state, *this); + IRRewriter rewriter(getContext(), &listener); + int64_t numIterations = 0; + + // `getPayloadOps` returns an iterator that skips ops that are erased in the + // loop body. Replacement ops are not enumerated. + for (Operation *op : state.getPayloadOps(getIn())) { + ++numIterations; + rewriter.setInsertionPointToEnd(op->getBlock()); + + // Erase all payload ops. The outer loop should have only one iteration. + for (Operation *op : state.getPayloadOps(getIn())) { + SmallVector attributes; + attributes.emplace_back(rewriter.getStringAttr("original_op"), + op->getName().getIdentifier()); + OperationState opState(op->getLoc(), "test.foo", + /*operands=*/ValueRange(), + /*types=*/op->getResultTypes(), attributes); + Operation *newOp = rewriter.create(opState); + rewriter.replaceOp(op, newOp->getResults()); + } + } + + emitRemark() << numIterations << " iterations"; + return DiagnosedSilenceableFailure::success(); +} + namespace { /// Test extension of the Transform dialect. Registers additional ops and /// declares PDL as dependent dialect since the additional ops are using PDL diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td @@ -449,4 +449,14 @@ let cppNamespace = "::mlir::test"; } +def TestTrackedRewriteOp + : Op, + DeclareOpInterfaceMethods]> { + let arguments = (ins TransformHandleTypeInterface:$in); + let results = (outs); + let assemblyFormat = "$in attr-dict `:` functional-type(operands, results)"; + let cppNamespace = "::mlir::test"; +} + #endif // MLIR_TESTTRANSFORMDIALECTEXTENSION_TD