diff --git a/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h --- a/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h +++ b/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h @@ -53,15 +53,15 @@ DiagnosedSilenceableFailure apply(TransformResults &results, TransformState &state) { Value operandHandle = cast(this->getOperation()).getOperandHandle(); - ArrayRef payload = state.getPayloadOps(operandHandle); - if (payload.size() != 1) { + auto payload = state.getPayloadOps(operandHandle); + if (!llvm::hasSingleElement(payload)) { return emitDefiniteFailure(this->getOperation()->getLoc()) << "SingleOpMatchOpTrait requires the operand handle to point to " "a single payload op"; } return cast(this->getOperation()) - .matchOperation(payload[0], results, state); + .matchOperation(*payload.begin(), results, state); } void getEffects(SmallVectorImpl &effects) { diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h --- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h @@ -22,6 +22,36 @@ class TransformOpInterface; class TransformResults; +class TransformState; + +using Param = Attribute; +using MappedValue = llvm::PointerUnion; + +namespace detail { +/// Maps the only block argument of the op with PossibleTopLevelTransformOpTrait +/// to either the list of operations associated with its operand or the root of +/// the payload IR, depending on what is available in the context. +LogicalResult +mapPossibleTopLevelTransformOpBlockArguments(TransformState &state, + Operation *op, Region &region); + +/// Verification hook for PossibleTopLevelTransformOpTrait. +LogicalResult verifyPossibleTopLevelTransformOpTrait(Operation *op); + +/// Verification hook for TransformOpInterface. 
+LogicalResult verifyTransformOpInterface(Operation *op); + +/// Populates `mappings` with mapped values associated with the given transform +/// IR values in the given `state`. +void prepareValueMappings( + SmallVectorImpl> &mappings, + ValueRange values, const transform::TransformState &state); + +/// Populates `results` with payload associations that match exactly those of +/// the operands to `block`'s terminator. +void forwardTerminatorOperands(Block *block, transform::TransformState &state, + transform::TransformResults &results); +} // namespace detail /// Options controlling the application of transform operations by the /// TransformState. @@ -46,9 +76,6 @@ bool expensiveChecksEnabled = true; }; -using Param = Attribute; -using MappedValue = llvm::PointerUnion; - /// Entry point to the Transform dialect infrastructure. Applies the /// transformation specified by `transform` to payload IR contained in /// `payloadRoot`. The `transform` operation may contain other operations that @@ -140,9 +167,16 @@ return topLevelMappedValues[position]; } - /// Returns the list of ops that the given transform IR value corresponds to. - /// This is helpful for transformations that apply to a particular handle. - ArrayRef getPayloadOps(Value value) const; + /// Returns an iterator that enumerates all ops that the given transform IR + /// value corresponds to. Ops may be erased while iterating; erased ops are + /// not enumerated. This function is helpful for transformations that apply to + /// a particular handle. + auto getPayloadOps(Value value) const { + // When ops are replaced/erased, they are replaced with nullptr (until + // the data structure is compacted). Do not enumerate these ops. + return llvm::make_filter_range(getPayloadOpsView(value), + [](Operation *op) { return op != nullptr; }); + } /// Returns the list of parameters that the given transform IR value /// corresponds to. 
@@ -407,6 +441,12 @@ LogicalResult updateStateFromResults(const TransformResults &results, ResultRange opResults); + /// Returns a list of all ops that the given transform IR value corresponds to + /// at the time when this function is called. In case an op was erased, the + /// returned list contains nullptr. This function is helpful for + /// transformations that apply to a particular handle. + ArrayRef getPayloadOpsView(Value value) const; + /// Sets the payload IR ops associated with the given transform IR value /// (handle). A payload op may be associated multiple handles as long as /// at most one of them gets consumed by further transformations. @@ -540,10 +580,19 @@ LogicalResult checkAndRecordHandleInvalidation(TransformOpInterface transform); + /// Remove all nullptrs from op handles that were added by `replacePayloadOp`. + void compactOpHandles(); + /// The mappings between transform IR values and payload IR ops, aggregated by /// the region in which the transform IR values are defined. llvm::SmallDenseMap mappings; + /// Op handles may be temporarily mapped to nullptr to avoid invalidating + /// payload op iterators. This set contains all op handles with nullptrs. + /// These handles are "compacted" (i.e., nullptrs removed) at the end of each + /// transform. + DenseSet opHandlesToCompact; + /// Extensions attached to the TransformState, identified by the TypeID of /// their type. Only one extension of any given type is allowed. DenseMap> extensions; @@ -595,7 +644,25 @@ /// corresponds to the given list of payload IR ops. Each result must be set /// by the transformation exactly once in case of transformation succeeding. /// The value must have a type implementing TransformHandleTypeInterface. 
- void set(OpResult value, ArrayRef ops); + template void set(OpResult value, Range &&ops) { + int64_t position = value.getResultNumber(); + assert(position < static_cast(operations.size()) && + "setting results for a non-existent handle"); + assert(operations[position].data() == nullptr && "results already set"); + assert(params[position].data() == nullptr && + "another kind of results already set"); + assert(values[position].data() == nullptr && + "another kind of results already set"); + operations.replace(position, std::forward(ops)); + } + + /// Indicates that the result of the transform IR op at the given position + /// corresponds to the given list of payload IR ops. Each result must be set + /// by the transformation exactly once in case of transformation succeeding. + /// The value must have a type implementing TransformHandleTypeInterface. + void set(OpResult value, std::initializer_list ops) { + set(value, ArrayRef(ops)); + } /// Indicates that the result of the transform IR op at the given position /// corresponds to the given list of parameters. Each result must be set by @@ -682,32 +749,6 @@ return RegionScope(*this, region, RegionScope::Isolated()); } -namespace detail { -/// Maps the only block argument of the op with PossibleTopLevelTransformOpTrait -/// to either the list of operations associated with its operand or the root of -/// the payload IR, depending on what is available in the context. -LogicalResult -mapPossibleTopLevelTransformOpBlockArguments(TransformState &state, - Operation *op, Region &region); - -/// Verification hook for PossibleTopLevelTransformOpTrait. -LogicalResult verifyPossibleTopLevelTransformOpTrait(Operation *op); - -/// Verification hook for TransformOpInterface. -LogicalResult verifyTransformOpInterface(Operation *op); - -/// Populates `mappings` with mapped values associated with the given transform -/// IR values in the given `state`. 
-void prepareValueMappings( - SmallVectorImpl> &mappings, - ValueRange values, const transform::TransformState &state); - -/// Populates `results` with payload associations that match exactly those of -/// the operands to `block`'s terminator. -void forwardTerminatorOperands(Block *block, transform::TransformState &state, - transform::TransformResults &results); -} // namespace detail - /// This trait is supposed to be attached to Transform dialect operations that /// can be standalone top-level transforms. Such operations typically contain /// other Transform dialect operations that can be executed following some @@ -1069,9 +1110,9 @@ /// - a concrete Op class, in which case a check is performed whether /// `targets` contains operations of the same class and a silenceable failure /// is reported if it does not. -template +template DiagnosedSilenceableFailure -applyTransformToEach(TransformOpTy transformOp, ArrayRef targets, +applyTransformToEach(TransformOpTy transformOp, Range &&targets, SmallVectorImpl &results, TransformState &state) { using OpTy = typename llvm::function_traits< @@ -1133,14 +1174,13 @@ mlir::DiagnosedSilenceableFailure mlir::transform::TransformEachOpTrait::apply( TransformResults &transformResults, TransformState &state) { - ArrayRef targets = - state.getPayloadOps(this->getOperation()->getOperand(0)); + auto targets = state.getPayloadOps(this->getOperation()->getOperand(0)); // Step 1. Handle the corner case where no target is specified. // This is typically the case when the matcher fails to apply and we need to // propagate gracefully. // In this case, we fill all results with an empty vector. - if (targets.empty()) { + if (std::empty(targets)) { SmallVector emptyPayload; SmallVector emptyParams; for (OpResult r : this->getOperation()->getResults()) { @@ -1157,7 +1197,6 @@ // Step 2. Call applyToOne on each target and record newly produced ops in its // corresponding results entry. 
SmallVector results; - results.reserve(targets.size()); DiagnosedSilenceableFailure result = detail::applyTransformToEach( cast(this->getOperation()), targets, results, state); diff --git a/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp b/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp --- a/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp +++ b/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp @@ -75,8 +75,7 @@ for (const auto &it : llvm::zip_equal(getBoundedValues(), getLowerBounds(), getUpperBounds())) { Value handle = std::get<0>(it); - ArrayRef boundedValueOps = state.getPayloadOps(handle); - for (Operation *op : boundedValueOps) { + for (Operation *op : state.getPayloadOps(handle)) { if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) { auto diag = emitDefiniteFailure() @@ -104,8 +103,8 @@ } // Transform all targets. - ArrayRef targets = state.getPayloadOps(getTarget()); - for (Operation *target : targets) { + SmallVector targets; + for (Operation *target : state.getPayloadOps(getTarget())) { if (!isa(target)) { auto diag = emitDefiniteFailure() << "target must be affine.min or affine.max"; @@ -118,6 +117,7 @@ diag.attachNote(target->getLoc()) << "target/constrained op"; return diag; } + targets.push_back(target); } SmallVector transformed; RewritePatternSet patterns(getContext()); diff --git a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp --- a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp +++ b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp @@ -41,7 +41,7 @@ options.setFunctionBoundaryTypeConversion( *getFunctionBoundaryTypeConversion()); - ArrayRef payloadOps = state.getPayloadOps(getTarget()); + auto payloadOps = state.getPayloadOps(getTarget()); for (Operation *target : payloadOps) { if (!isa(target)) return emitSilenceableError() << 
"expected module or function target"; @@ -80,8 +80,7 @@ OneShotBufferizationOptions options; options.allowReturnAllocs = true; - ArrayRef payloadOps = state.getPayloadOps(getTarget()); - for (Operation *target : payloadOps) { + for (Operation *target : state.getPayloadOps(getTarget())) { OneShotAnalysisState state(target, options); if (failed(analyzeOp(target, state))) return mlir::emitSilenceableFailure(target->getLoc()) diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp @@ -618,7 +618,7 @@ } Operation *firstUser = *result.getUsers().begin(); if (getAny()) { - results.set(cast(getResult()), firstUser); + results.set(cast(getResult()), {firstUser}); return DiagnosedSilenceableFailure::success(); } if (getSingle()) { @@ -626,7 +626,7 @@ return emitSilenceableError() << "more than one result user with single user requested"; } - results.set(cast(getResult()), firstUser); + results.set(cast(getResult()), {firstUser}); return DiagnosedSilenceableFailure::success(); } diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -92,17 +92,17 @@ result.push_back(ofr); continue; } - ArrayRef payloadOps = state.getPayloadOps(ofr.get()); - if (payloadOps.size() != 1) { + auto payloadOps = state.getPayloadOps(ofr.get()); + if (!llvm::hasSingleElement(payloadOps)) { DiagnosedSilenceableFailure diag = transformOp.emitSilenceableError() << "handle must be mapped to exactly one payload op"; diag.attachNote(ofr.get().getLoc()) - << "mapped to " << payloadOps.size() << " payload ops"; + << "mapped to " << llvm::range_size(payloadOps) << " payload ops"; return diag; } - Operation 
*op = payloadOps[0]; + Operation *op = *payloadOps.begin(); if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) { DiagnosedSilenceableFailure diag = transformOp.emitSilenceableError() @@ -125,8 +125,7 @@ static DiagnosedSilenceableFailure unpackSingleIndexResultPDLOperations( transform::TransformState &state, TransformOpInterface transformOp, SmallVector &result, Value packedHandle) { - ArrayRef payloadOps = state.getPayloadOps(packedHandle); - for (Operation *op : payloadOps) { + for (Operation *op : state.getPayloadOps(packedHandle)) { if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) { DiagnosedSilenceableFailure diag = transformOp.emitSilenceableError() @@ -208,16 +207,14 @@ /// Apply a tiling transformation to all payload ops and store both the /// tiled operation as well as the created tile loops. +template static LogicalResult applyTilingToAll( - RewriterBase &rewriter, Operation *transformOp, - ArrayRef payloadOps, unsigned numLoops, - transform::TransformResults &transformResults, + RewriterBase &rewriter, Operation *transformOp, Range &&payloadOps, + unsigned numLoops, transform::TransformResults &transformResults, function_ref(TilingInterface)> applyFn) { SmallVector tiledLinalgOps; SmallVector> loopOps(numLoops); - for (unsigned int i = 0; i < numLoops; ++i) - loopOps[i].reserve(payloadOps.size()); for (Operation *target : payloadOps) { auto tilingInterfaceOp = dyn_cast(target); @@ -578,19 +575,19 @@ transform::FuseIntoContainingOp::apply(transform::TransformResults &results, transform::TransformState &state) { SmallVector fusedOps; - ArrayRef producerOps = state.getPayloadOps(getProducerOp()); + auto producerOps = state.getPayloadOps(getProducerOp()); // If nothing to fuse, propagate success. 
- if (producerOps.empty()) { + if (std::empty(producerOps)) { results.set(cast(getFusedOp()), SmallVector{}); return DiagnosedSilenceableFailure::success(); } - ArrayRef containingOps = state.getPayloadOps(getContainingOp()); - if (containingOps.size() != 1) { + auto containingOps = state.getPayloadOps(getContainingOp()); + if (!llvm::hasSingleElement(containingOps)) { return emitDefiniteFailure() << "requires exactly one containing_op handle (got " - << containingOps.size() << ")"; + << llvm::range_size(containingOps) << ")"; } - Operation *containingOp = containingOps.front(); + Operation *containingOp = *containingOps.begin(); // Helper function to find the next producer that should be fused. Take any // producer that has a use inside the containing op. @@ -810,8 +807,8 @@ strs.insert(getOps()->getAsValueRange().begin(), getOps()->getAsValueRange().end()); - ArrayRef payloadOps = state.getPayloadOps(getTarget()); - if (payloadOps.size() != 1) { + auto payloadOps = state.getPayloadOps(getTarget()); + if (!llvm::hasSingleElement(payloadOps)) { return emitDefiniteFailure("requires exactly one target handle"); } @@ -857,7 +854,7 @@ return; }; - payloadOps.front()->walk(matchFun); + (*payloadOps.begin())->walk(matchFun); results.set(cast(getResult()), res); return DiagnosedSilenceableFailure::success(); } @@ -996,18 +993,19 @@ DiagnosedSilenceableFailure transform::PackOp::apply(transform::TransformResults &transformResults, transform::TransformState &state) { - ArrayRef targetOps = state.getPayloadOps(getTarget()); + auto targetOps = state.getPayloadOps(getTarget()); // If nothing to pack, propagate success. - if (targetOps.empty()) { - transformResults.set(cast(getPackedOp()), {}); + if (std::empty(targetOps)) { + transformResults.set(cast(getPackedOp()), + ArrayRef({})); return DiagnosedSilenceableFailure::success(); } // Fail on multi-op handles. 
- auto linalgOp = dyn_cast(targetOps.front()); - if (targetOps.size() != 1 || !linalgOp) { + auto linalgOp = dyn_cast(*targetOps.begin()); + if (!llvm::hasSingleElement(targetOps) || !linalgOp) { return emitSilenceableError() << "requires target to map to exactly 1 LinalgOp (got " - << targetOps.size() << ")"; + << llvm::range_size(targetOps) << ")"; } // Fail on mismatched number of pack sizes. if (getMixedPackedSizes().size() != linalgOp.getNumLoops()) { @@ -1030,7 +1028,7 @@ return emitDefiniteFailure("data tiling failed"); transformResults.set(cast(getPackedOp()), - maybeResult->packedLinalgOp.getOperation()); + {maybeResult->packedLinalgOp.getOperation()}); return DiagnosedSilenceableFailure::success(); } @@ -1204,16 +1202,10 @@ DiagnosedSilenceableFailure PackGreedilyOp::apply(transform::TransformResults &transformResults, transform::TransformState &state) { - ArrayRef targetOpsView = state.getPayloadOps(getTarget()); - // Store payload ops into a separate SmallVector because the TrackingListener - // removes erased ops from the transform state. - SmallVector targetOps(targetOpsView.begin(), - targetOpsView.end()); - SmallVector results; TrackingListener listener(state, *this); IRRewriter rewriter(getContext(), &listener); - for (Operation *op : targetOps) { + for (Operation *op : state.getPayloadOps(getTarget())) { auto linalgOp = dyn_cast(op); if (!linalgOp) continue; @@ -1310,11 +1302,10 @@ DiagnosedSilenceableFailure transform::PackTransposeOp::apply(transform::TransformResults &transformResults, transform::TransformState &state) { - ArrayRef packOrUnpackOps = - state.getPayloadOps(getTargetPackOrUnPackOp()); - ArrayRef linalgOps = state.getPayloadOps(getTargetLinalgOp()); + auto packOrUnpackOps = state.getPayloadOps(getTargetPackOrUnPackOp()); + auto linalgOps = state.getPayloadOps(getTargetLinalgOp()); // Step 1. If nothing to pack, propagate success. 
- if (packOrUnpackOps.empty()) { + if (std::empty(packOrUnpackOps)) { transformResults.set(cast(getPackedOp()), {}); transformResults.set(cast(getPackOp()), {}); transformResults.set(cast(getUnPackOp()), {}); @@ -1323,21 +1314,23 @@ // Step 2. Bunch of runtime sanity check and error messages. // Step 2.1. Fail on multi-op handles. - if (packOrUnpackOps.size() != 1 || linalgOps.size() != 1) { - return emitSilenceableError() << "requires target to map to exactly 1 " - "packing op and 1 packed op (" - << "got " << packOrUnpackOps.size() << " and " - << linalgOps.size() << ")"; + if (!llvm::hasSingleElement(packOrUnpackOps) || + !llvm::hasSingleElement(linalgOps)) { + return emitSilenceableError() + << "requires target to map to exactly 1 " + "packing op and 1 packed op (" + << "got " << llvm::range_size(packOrUnpackOps) << " and " + << llvm::range_size(linalgOps) << ")"; } // Step 2.2. Fail on wrong type. - auto packOp = dyn_cast(packOrUnpackOps.front()); - auto unPackOp = dyn_cast(packOrUnpackOps.front()); + auto packOp = dyn_cast(*packOrUnpackOps.begin()); + auto unPackOp = dyn_cast(*packOrUnpackOps.begin()); if ((!packOp && !unPackOp)) { return emitSilenceableError() << "requires target to map to a " "tensor.pack or tensor.unpack"; } - LinalgOp linalgOpTarget = dyn_cast(linalgOps.front()); + LinalgOp linalgOpTarget = dyn_cast(*linalgOps.begin()); if (!linalgOpTarget) return emitSilenceableError() << "requires a LinalgOp target"; @@ -1520,16 +1513,17 @@ DiagnosedSilenceableFailure transform::HoistPadBuildPackingLoopNestOp::apply( transform::TransformResults &transformResults, transform::TransformState &state) { - ArrayRef targetOps = state.getPayloadOps(getTarget()); - ArrayRef loopOps = state.getPayloadOps(getLoop()); - if (targetOps.size() != 1 || loopOps.size() != 1) { + auto targetOps = state.getPayloadOps(getTarget()); + auto loopOps = state.getPayloadOps(getLoop()); + if (!llvm::hasSingleElement(targetOps) || !llvm::hasSingleElement(loopOps)) { return 
emitDefiniteFailure() << "requires exactly one target and one loop handle (got " - << targetOps.size() << " and " << loopOps.size() << ")"; + << llvm::range_size(targetOps) << " and " + << llvm::range_size(loopOps) << ")"; } - auto padOp = dyn_cast_or_null(targetOps.front()); - auto loopOp = dyn_cast_or_null(loopOps.front()); + auto padOp = dyn_cast_or_null(*targetOps.begin()); + auto loopOp = dyn_cast_or_null(*loopOps.begin()); if (!padOp || !loopOp) return emitDefiniteFailure() << "requires exactly 2 non-null handles"; @@ -1543,13 +1537,13 @@ if (result->clonedLoopIvs.empty()) { transformResults.set(cast(getPackingLoop()), - result->hoistedPadOp.getOperation()); + {result->hoistedPadOp.getOperation()}); return DiagnosedSilenceableFailure::success(); } auto outerPackedLoop = scf::getForInductionVarOwner(result->clonedLoopIvs.front()); transformResults.set(cast(getPackingLoop()), - outerPackedLoop.getOperation()); + {outerPackedLoop.getOperation()}); return DiagnosedSilenceableFailure::success(); } @@ -1679,7 +1673,7 @@ DiagnosedSilenceableFailure transform::ReplaceOp::apply(TransformResults &transformResults, TransformState &state) { - ArrayRef payload = state.getPayloadOps(getTarget()); + auto payload = state.getPayloadOps(getTarget()); // Check for invalid targets. for (Operation *target : payload) { @@ -1814,7 +1808,8 @@ DiagnosedSilenceableFailure SplitOp::apply(TransformResults &results, TransformState &state) { // Collect the dynamic split points if provided. 
- ArrayRef payload = state.getPayloadOps(getTarget()); + SmallVector payload = + llvm::to_vector(state.getPayloadOps(getTarget())); TrackingListener listener(state, *this); IRRewriter rewriter(getContext(), &listener); SmallVector splitPoints; @@ -2199,8 +2194,9 @@ TransformState &state) { ArrayRef tileSizes = getStaticSizes(); - ArrayRef targets = state.getPayloadOps(getTarget()); - SmallVector> dynamicSizeProducers; + SmallVector targets = + llvm::to_vector(state.getPayloadOps(getTarget())); + SmallVector> dynamicSizeProducers; SmallVector> paramSizes; dynamicSizeProducers.reserve(getDynamicSizes().size()); paramSizes.reserve(getDynamicSizes().size()); @@ -2226,7 +2222,8 @@ continue; } paramSizes.push_back({}); - dynamicSizeProducers.push_back(state.getPayloadOps(transformValue)); + dynamicSizeProducers.push_back( + llvm::to_vector(state.getPayloadOps(transformValue))); if (dynamicSizeProducers.back().size() != targets.size()) { DiagnosedSilenceableFailure diag = @@ -2536,10 +2533,6 @@ TrackingListener listener(state, *this); IRRewriter rewriter(getContext(), &listener); auto transformOp = cast(getOperation()); - ArrayRef targetsView = state.getPayloadOps(getTarget()); - // Store payload ops into a separate SmallVector because the TrackingListener - // removes erased ops from the transform state. - SmallVector targets(targetsView.begin(), targetsView.end()); // Result payload ops. 
SmallVector tileOps; @@ -2564,7 +2557,7 @@ if (!status.succeeded()) return status; - for (Operation *target : targets) { + for (Operation *target : state.getPayloadOps(getTarget())) { linalg::ForallTilingResult tilingResult; DiagnosedSilenceableFailure diag = tileToForallOpImpl( rewriter, state, transformOp, target, mixedNumThreads, mixedTileSizes, @@ -2652,12 +2645,13 @@ TransformState &state) { ArrayRef tileSizes = getStaticSizes(); - ArrayRef targets = state.getPayloadOps(getTarget()); - SmallVector> dynamicSizeProducers; + SmallVector targets = + llvm::to_vector(state.getPayloadOps(getTarget())); + SmallVector> dynamicSizeProducers; dynamicSizeProducers.reserve(getDynamicSizes().size()); for (Value dynamicSizeProducerHandle : getDynamicSizes()) { dynamicSizeProducers.push_back( - state.getPayloadOps(dynamicSizeProducerHandle)); + llvm::to_vector(state.getPayloadOps(dynamicSizeProducerHandle))); if (dynamicSizeProducers.back().size() != targets.size()) { DiagnosedSilenceableFailure diag = @@ -2884,8 +2878,8 @@ mlir::transform::TransformState &state) { TrackingListener listener(state, *this); IRRewriter rewriter(getContext(), &listener); - ArrayRef targets = state.getPayloadOps(getTarget()); - if (targets.empty()) + auto targets = state.getPayloadOps(getTarget()); + if (std::empty(targets)) return DiagnosedSilenceableFailure::success(); SmallVector vectorSizes; @@ -2896,16 +2890,16 @@ continue; } - ArrayRef szPayloads = state.getPayloadOps(sz.get()); - if (szPayloads.size() != 1) { + auto szPayloads = state.getPayloadOps(sz.get()); + if (!llvm::hasSingleElement(szPayloads)) { auto diag = this->emitOpError( "requires vector size handle that is mapped to 1 payload op"); diag.attachNote(sz.get().getLoc()) - << "mapped to " << szPayloads.size() << " payload ops"; + << "mapped to " << llvm::range_size(szPayloads) << " payload ops"; return DiagnosedSilenceableFailure::definiteFailure(); } - Operation *szPayloadOp = szPayloads[0]; + Operation *szPayloadOp = 
*szPayloads.begin(); if (szPayloadOp->getNumResults() != 1 || !szPayloadOp->getResult(0).getType().isIndex()) { auto diag = this->emitOpError( diff --git a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp --- a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp +++ b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp @@ -35,9 +35,8 @@ transform::TransformResults &transformResults, transform::TransformState &state) { SmallVector results; - ArrayRef payloadOps = state.getPayloadOps(getTarget()); IRRewriter rewriter(getContext()); - for (auto *op : payloadOps) { + for (Operation *op : state.getPayloadOps(getTarget())) { bool canApplyMultiBuffer = true; auto target = cast(op); LLVM_DEBUG(DBGS() << "Start multibuffer transform op: " << target << "\n";); diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp @@ -53,7 +53,7 @@ Operation *transform::TransformState::getTopLevel() const { return topLevel; } ArrayRef -transform::TransformState::getPayloadOps(Value value) const { +transform::TransformState::getPayloadOpsView(Value value) const { const TransformOpMapping &operationMapping = getMapping(value).direct; auto iter = operationMapping.find(value); assert( @@ -357,18 +357,8 @@ // TODO: consider invalidating the handles to nested objects here. - // If replacing with null, that is erasing the mapping, drop the mapping - // between the handles and the IR objects and return. 
- if (!replacement) { - for (Value handle : opHandles) { - Mappings &mappings = getMapping(handle); - dropMappingEntry(mappings.direct, handle, op); - } - return success(); - } - #if LLVM_ENABLE_ABI_BREAKING_CHECKS - if (options.getExpensiveChecksEnabled()) { + if (replacement && options.getExpensiveChecksEnabled()) { auto insertion = cachedNames.insert({replacement, replacement->getName()}); if (!insertion.second) { assert(insertion.first->second == replacement->getName() && @@ -377,8 +367,15 @@ } #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS - // Otherwise, replace the pointed-to object of all handles while preserving - // their relative order. First, replace the mapped operation if present. + // Replace the pointed-to object of all handles with the replacement object. + // In case a payload op was erased (replacement object is nullptr), a nullptr + // is stored in the mapping. These nullptrs are removed after each transform. + // Furthermore, nullptrs are not enumerated by payload op iterators. The + // relative order of ops is preserved. + // + // Removing an op from the mapping would be problematic because removing an + // element from an array invalidates iterators; merely changing the value of + // elements does not. 
for (Value handle : opHandles) { Mappings &mappings = getMapping(handle); auto it = mappings.direct.find(handle); @@ -391,7 +388,12 @@ if (mapped == op) mapped = replacement; } - mappings.reverse[replacement].push_back(handle); + + if (replacement) { + mappings.reverse[replacement].push_back(handle); + } else { + opHandlesToCompact.insert(handle); + } } return success(); @@ -645,7 +647,7 @@ FULL_LDBG("----found consume effect -> SKIP\n"); if (llvm::isa(target.get().getType())) { FULL_LDBG("----recordOpHandleInvalidation\n"); - ArrayRef payloadOps = getPayloadOps(target.get()); + ArrayRef payloadOps = getPayloadOpsView(target.get()); recordOpHandleInvalidation(target, payloadOps); } else if (llvm::isa( target.get().getType())) { @@ -686,6 +688,14 @@ return DiagnosedSilenceableFailure::success(); } +void transform::TransformState::compactOpHandles() { + for (Value handle : opHandlesToCompact) { + Mappings &mappings = getMapping(handle); + llvm::erase_value(mappings.direct[handle], nullptr); + } + opHandlesToCompact.clear(); +} + DiagnosedSilenceableFailure transform::TransformState::applyTransform(TransformOpInterface transform) { LLVM_DEBUG({ @@ -721,7 +731,7 @@ FULL_LDBG("--checkRepeatedConsumptionInOperand for Operation*\n"); DiagnosedSilenceableFailure check = checkRepeatedConsumptionInOperand( - getPayloadOps(operand.get()), transform, + getPayloadOpsView(operand.get()), transform, operand.getOperandNumber()); if (!check.succeeded()) { FULL_LDBG("----FAILED\n"); @@ -835,6 +845,7 @@ // proceed on a best effort basis. 
transform::TransformResults results(transform->getNumResults()); DiagnosedSilenceableFailure result(transform.apply(results, *this)); + compactOpHandles(); if (result.isDefiniteFailure()) return result; @@ -988,19 +999,6 @@ values.appendEmptyRows(numSegments); } -void transform::TransformResults::set(OpResult value, - ArrayRef ops) { - int64_t position = value.getResultNumber(); - assert(position < static_cast(operations.size()) && - "setting results for a non-existent handle"); - assert(operations[position].data() == nullptr && "results already set"); - assert(params[position].data() == nullptr && - "another kind of results already set"); - assert(values[position].data() == nullptr && - "another kind of results already set"); - operations.replace(position, ops); -} - void transform::TransformResults::setParams( OpResult value, ArrayRef params) { int64_t position = value.getResultNumber(); diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -767,10 +767,9 @@ DiagnosedSilenceableFailure transform::ForeachOp::apply(transform::TransformResults &results, transform::TransformState &state) { - ArrayRef payloadOps = state.getPayloadOps(getTarget()); SmallVector> resultOps(getNumResults(), {}); - for (Operation *op : payloadOps) { + for (Operation *op : state.getPayloadOps(getTarget())) { auto scope = state.make_region_scope(getBody()); if (failed(state.mapBlockArguments(getIterationVariable(), {op}))) return DiagnosedSilenceableFailure::definiteFailure(); @@ -785,8 +784,7 @@ // Append yielded payload ops to result list (if any). 
for (unsigned i = 0; i < getNumResults(); ++i) { - ArrayRef yieldedOps = - state.getPayloadOps(getYieldOp().getOperand(i)); + auto yieldedOps = state.getPayloadOps(getYieldOp().getOperand(i)); resultOps[i].append(yieldedOps.begin(), yieldedOps.end()); } } @@ -882,16 +880,16 @@ transform::GetConsumersOfResult::apply(transform::TransformResults &results, transform::TransformState &state) { int64_t resultNumber = getResultNumber(); - ArrayRef payloadOps = state.getPayloadOps(getTarget()); - if (payloadOps.empty()) { - results.set(llvm::cast(getResult()), {}); + auto payloadOps = state.getPayloadOps(getTarget()); + if (std::empty(payloadOps)) { + results.set(cast(getResult()), {}); return DiagnosedSilenceableFailure::success(); } - if (payloadOps.size() != 1) + if (!llvm::hasSingleElement(payloadOps)) return emitDefiniteFailure() << "handle must be mapped to exactly one payload op"; - Operation *target = payloadOps.front(); + Operation *target = *payloadOps.begin(); if (target->getNumResults() <= resultNumber) return emitDefiniteFailure() << "result number overflow"; results.set(llvm::cast(getResult()), @@ -1483,7 +1481,7 @@ DiagnosedSilenceableFailure transform::SplitHandleOp::apply(transform::TransformResults &results, transform::TransformState &state) { - int64_t numPayloadOps = state.getPayloadOps(getHandle()).size(); + int64_t numPayloadOps = llvm::range_size(state.getPayloadOps(getHandle())); auto produceNumOpsError = [&]() { return emitSilenceableError() << getHandle() << " expected to contain " << this->getNumResults() @@ -1573,11 +1571,12 @@ DiagnosedSilenceableFailure transform::ReplicateOp::apply(transform::TransformResults &results, transform::TransformState &state) { - unsigned numRepetitions = state.getPayloadOps(getPattern()).size(); + unsigned numRepetitions = llvm::range_size(state.getPayloadOps(getPattern())); for (const auto &en : llvm::enumerate(getHandles())) { Value handle = en.value(); - if (llvm::isa(handle.getType())) { - ArrayRef current = 
state.getPayloadOps(handle); + if (isa(handle.getType())) { + SmallVector current = + llvm::to_vector(state.getPayloadOps(handle)); SmallVector payload; payload.reserve(numRepetitions * current.size()); for (unsigned i = 0; i < numRepetitions; ++i) @@ -2011,8 +2010,7 @@ } llvm::outs() << "]]]\n"; - ArrayRef targets = state.getPayloadOps(getTarget()); - for (Operation *target : targets) + for (Operation *target : state.getPayloadOps(getTarget())) llvm::outs() << *target << "\n"; return DiagnosedSilenceableFailure::success(); diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir --- a/mlir/test/Dialect/Transform/test-interpreter.mlir +++ b/mlir/test/Dialect/Transform/test-interpreter.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s --test-transform-dialect-interpreter -allow-unregistered-dialect --split-input-file --verify-diagnostics +// RUN: mlir-opt %s --test-transform-dialect-interpreter -allow-unregistered-dialect --split-input-file --verify-diagnostics | FileCheck %s transform.sequence failures(propagate) { ^bb0(%arg0: !transform.any_op): @@ -1598,3 +1598,26 @@ return } } + +// ----- + +// CHECK-LABEL: func @test_tracked_rewrite() { +// CHECK-NEXT: "test.update_mapping"() {original_op = "test.replace_me"} +// CHECK-NEXT: "test.drop_mapping"() {original_op = "test.replace_me"} +// CHECK-NEXT: "test.update_mapping"() {original_op = "test.replace_me"} +// CHECK-NEXT: } +func.func @test_tracked_rewrite() { + %0 = "test.replace_me"() {replacement = "test.update_mapping"} : () -> (i1) + %1 = "test.replace_me"() {replacement = "test.drop_mapping"} : () -> (i1) + %2 = "test.replace_me"() {replacement = "test.update_mapping"} : () -> (i1) +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["test.replace_me"]} in %arg1 : (!pdl.operation) -> !pdl.operation + // expected-remark @below {{2 iterations}} + transform.test_tracked_rewrite %0 : (!pdl.operation) 
-> () + // One replacement op (test.drop_mapping) is dropped from the mapping. + // expected-remark @below {{2}} + test_print_number_of_associated_payload_ir_ops %0 +} diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp @@ -16,6 +16,7 @@ #include "mlir/Dialect/PDL/IR/PDL.h" #include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" +#include "mlir/Dialect/Transform/IR/TransformOps.h" #include "mlir/IR/OpImplementation.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/TypeSwitch.h" @@ -109,10 +110,10 @@ mlir::test::TestProduceSelfHandleOrForwardOperandOp::apply( transform::TransformResults &results, transform::TransformState &state) { if (getOperation()->getNumOperands() != 0) { - results.set(llvm::cast(getResult()), - getOperation()->getOperand(0).getDefiningOp()); + results.set(cast(getResult()), + {getOperation()->getOperand(0).getDefiningOp()}); } else { - results.set(llvm::cast(getResult()), getOperation()); + results.set(cast(getResult()), {getOperation()}); } return DiagnosedSilenceableFailure::success(); } @@ -191,12 +192,13 @@ DiagnosedSilenceableFailure mlir::test::TestConsumeOperandOfOpKindOrFail::apply( transform::TransformResults &results, transform::TransformState &state) { - ArrayRef payload = state.getPayloadOps(getOperand()); - assert(payload.size() == 1 && "expected a single target op"); - if (payload[0]->getName().getStringRef() != getOpKind()) { + auto payload = state.getPayloadOps(getOperand()); + assert(llvm::hasSingleElement(payload) && "expected a single target op"); + if ((*payload.begin())->getName().getStringRef() != getOpKind()) { return emitSilenceableError() << "op expected the operand to be associated a payload op of kind " - << getOpKind() 
<< " got " << payload[0]->getName().getStringRef(); + << getOpKind() << " got " + << (*payload.begin())->getName().getStringRef(); } emitRemark() << "succeeded"; @@ -230,7 +232,7 @@ DiagnosedSilenceableFailure mlir::test::TestPrintRemarkAtOperandOp::apply( transform::TransformResults &results, transform::TransformState &state) { - ArrayRef payload = state.getPayloadOps(getOperand()); + auto payload = state.getPayloadOps(getOperand()); for (Operation *op : payload) op->emitRemark() << getMessage(); @@ -313,11 +315,11 @@ if (!extension) return emitDefiniteFailure("TestTransformStateExtension missing"); - if (failed(extension->updateMapping(state.getPayloadOps(getOperand()).front(), - getOperation()))) + if (failed(extension->updateMapping( + *state.getPayloadOps(getOperand()).begin(), getOperation()))) return DiagnosedSilenceableFailure::definiteFailure(); if (getNumResults() > 0) - results.set(llvm::cast(getResult(0)), getOperation()); + results.set(cast(getResult(0)), {getOperation()}); return DiagnosedSilenceableFailure::success(); } @@ -337,7 +339,7 @@ DiagnosedSilenceableFailure mlir::test::TestReversePayloadOpsOp::apply(transform::TransformResults &results, transform::TransformState &state) { - ArrayRef payloadOps = state.getPayloadOps(getTarget()); + auto payloadOps = state.getPayloadOps(getTarget()); auto reversedOps = llvm::to_vector(llvm::reverse(payloadOps)); results.set(llvm::cast(getResult()), reversedOps); return DiagnosedSilenceableFailure::success(); @@ -431,7 +433,7 @@ transform::TransformResults &results, transform::TransformState &state) { if (!getHandle()) emitRemark() << 0; - emitRemark() << state.getPayloadOps(getHandle()).size(); + emitRemark() << llvm::range_size(state.getPayloadOps(getHandle())); return DiagnosedSilenceableFailure::success(); } @@ -599,11 +601,11 @@ } else if (getFirstResultIsNull()) { results.push_back(nullptr); } else { - results.push_back(state.getPayloadOps(getIn()).front()); + 
results.push_back(*state.getPayloadOps(getIn()).begin()); } if (getSecondResultIsHandle()) { - results.push_back(state.getPayloadOps(getIn()).front()); + results.push_back(*state.getPayloadOps(getIn()).begin()); } else { results.push_back(builder.getI64IntegerAttr(42)); } @@ -667,6 +669,70 @@ return DiagnosedSilenceableFailure::success(); } +void mlir::test::TestTrackedRewriteOp::getEffects( + SmallVectorImpl &effects) { + transform::onlyReadsHandle(getIn(), effects); + transform::modifiesPayload(effects); +} + +namespace { +/// A TrackingListener for test cases. When the replacement op is +/// "test.update_mapping", it is considered as a replacement op in the transform +/// state mapping. Otherwise, it is not and the original op is simply removed +/// from the mapping. +class TestTrackingListener : public transform::TrackingListener { + using transform::TrackingListener::TrackingListener; + +protected: + Operation *findReplacementOp(Operation *op, + ValueRange newValues) const override { + if (newValues.size() != 1) + return nullptr; + Operation *replacement = newValues[0].getDefiningOp(); + if (!replacement) + return nullptr; + if (replacement->getName().getStringRef() != "test.update_mapping") + return nullptr; + return replacement; + } +}; +} // namespace + +DiagnosedSilenceableFailure +mlir::test::TestTrackedRewriteOp::apply(transform::TransformResults &results, + transform::TransformState &state) { + TestTrackingListener listener(state, *this); + IRRewriter rewriter(getContext(), &listener); + int64_t numIterations = 0; + + // `getPayloadOps` returns an iterator that skips ops that are erased in the + // loop body. Replacement ops are not enumerated. + for (Operation *op : state.getPayloadOps(getIn())) { + ++numIterations; + rewriter.setInsertionPointToEnd(op->getBlock()); + + // Erase all payload ops. The outer loop should have only one iteration. 
+ for (Operation *op : state.getPayloadOps(getIn())) { + if (op->getName().getStringRef() != "test.replace_me") + continue; + auto replacementName = op->getAttrOfType("replacement"); + if (!replacementName) + continue; + SmallVector attributes; + attributes.emplace_back(rewriter.getStringAttr("original_op"), + op->getName().getIdentifier()); + OperationState opState(op->getLoc(), replacementName, + /*operands=*/ValueRange(), + /*types=*/op->getResultTypes(), attributes); + Operation *newOp = rewriter.create(opState); + rewriter.replaceOp(op, newOp->getResults()); + } + } + + emitRemark() << numIterations << " iterations"; + return DiagnosedSilenceableFailure::success(); +} + namespace { /// Test extension of the Transform dialect. Registers additional ops and /// declares PDL as dependent dialect since the additional ops are using PDL diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td @@ -449,4 +449,14 @@ let cppNamespace = "::mlir::test"; } +def TestTrackedRewriteOp + : Op, + DeclareOpInterfaceMethods]> { + let arguments = (ins TransformHandleTypeInterface:$in); + let results = (outs); + let assemblyFormat = "$in attr-dict `:` functional-type(operands, results)"; + let cppNamespace = "::mlir::test"; +} + #endif // MLIR_TESTTRANSFORMDIALECTEXTENSION_TD