diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h --- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h @@ -695,6 +695,26 @@ StringRef getName() override { return "transform.payload_ir"; } }; +/// Populates `effects` with the memory effects indicating the operation on the +/// given handle value: +/// - consumes = Read + Free, +/// - produces = Allocate + Write, +/// - onlyReads = Read. +void consumesHandle(ValueRange handles, + SmallVectorImpl &effects); +void producesHandle(ValueRange handles, + SmallVectorImpl &effects); +void onlyReadsHandle(ValueRange handles, + SmallVectorImpl &effects); + +/// Checks whether the transform op consumes the given handle. +bool isHandleConsumed(Value handle, transform::TransformOpInterface transform); + +/// Populates `effects` with the memory effects indicating the access to payload +/// IR resource. +void modifiesPayload(SmallVectorImpl &effects); +void onlyReadsPayload(SmallVectorImpl &effects); + /// Trait implementing the MemoryEffectOpInterface for operations that "consume" /// their operands and produce new results. template @@ -705,20 +725,9 @@ /// the results by allocating and writing it and reads/writes the payload IR /// in the process. void getEffects(SmallVectorImpl &effects) { - for (Value operand : this->getOperation()->getOperands()) { - effects.emplace_back(MemoryEffects::Read::get(), operand, - TransformMappingResource::get()); - effects.emplace_back(MemoryEffects::Free::get(), operand, - TransformMappingResource::get()); - } - for (Value result : this->getOperation()->getResults()) { - effects.emplace_back(MemoryEffects::Allocate::get(), result, - TransformMappingResource::get()); - effects.emplace_back(MemoryEffects::Write::get(), result, - TransformMappingResource::get()); - } - effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get()); - effects.emplace_back(MemoryEffects::Write::get(), PayloadIRResource::get()); + consumesHandle(this->getOperation()->getOperands(), effects); + producesHandle(this->getOperation()->getResults(), effects); + modifiesPayload(effects); } /// Checks that the op matches the expectations of this trait. @@ -742,16 +751,9 @@ /// This op produces handles to the Payload IR without consuming the original /// handles and without modifying the IR itself. void getEffects(SmallVectorImpl &effects) { - effects.emplace_back(MemoryEffects::Read::get(), - this->getOperation()->getOperand(0), - TransformMappingResource::get()); - effects.emplace_back(MemoryEffects::Allocate::get(), - this->getOperation()->getResult(0), - TransformMappingResource::get()); - effects.emplace_back(MemoryEffects::Write::get(), - this->getOperation()->getResult(0), - TransformMappingResource::get()); - effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get()); + onlyReadsHandle(this->getOperation()->getOperands(), effects); + producesHandle(this->getOperation()->getResults(), effects); + onlyReadsPayload(effects); } /// Checks that the op matches the expectation of this trait. @@ -845,27 +847,6 @@ return res; } } // namespace detail - -/// Populates `effects` with the memory effects indicating the operation on the -/// given handle value: -/// - consumes = Read + Free, -/// - produces = Allocate + Write, -/// - onlyReads = Read. -void consumesHandle(ValueRange handles, - SmallVectorImpl &effects); -void producesHandle(ValueRange handles, - SmallVectorImpl &effects); -void onlyReadsHandle(ValueRange handles, - SmallVectorImpl &effects); - -/// Checks whether the transform op consumes the given handle. -bool isHandleConsumed(Value handle, transform::TransformOpInterface transform); - -/// Populates `effects` with the memory effects indicating the access to payload -/// IR resource. -void modifiesPayload(SmallVectorImpl &effects); -void onlyReadsPayload(SmallVectorImpl &effects); - } // namespace transform } // namespace mlir diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -312,19 +312,9 @@ void transform::MultiTileSizesOp::getEffects( SmallVectorImpl &effects) { - effects.emplace_back(MemoryEffects::Read::get(), getTarget(), - transform::TransformMappingResource::get()); - for (Value result : getResults()) { - effects.emplace_back(MemoryEffects::Allocate::get(), result, - transform::TransformMappingResource::get()); - effects.emplace_back(MemoryEffects::Write::get(), result, - transform::TransformMappingResource::get()); - } - - effects.emplace_back(MemoryEffects::Read::get(), - transform::PayloadIRResource::get()); - effects.emplace_back(MemoryEffects::Write::get(), - transform::PayloadIRResource::get()); + onlyReadsHandle(getTarget(), effects); + producesHandle(getResults(), effects); + modifiesPayload(effects); } //===---------------------------------------------------------------------===// @@ -527,28 +517,11 @@ void SplitOp::getEffects( SmallVectorImpl &effects) { - // The target handle is consumed. - effects.emplace_back(MemoryEffects::Read::get(), getTarget(), - TransformMappingResource::get()); - effects.emplace_back(MemoryEffects::Free::get(), getTarget(), - TransformMappingResource::get()); - - // The dynamic split point handle is not consumed. - if (getDynamicSplitPoint()) { - effects.emplace_back(MemoryEffects::Read::get(), getDynamicSplitPoint(), - TransformMappingResource::get()); - } - - // The resulting handles are produced. - for (Value result : getResults()) { - effects.emplace_back(MemoryEffects::Allocate::get(), result, - TransformMappingResource::get()); - effects.emplace_back(MemoryEffects::Write::get(), result, - TransformMappingResource::get()); - } - - effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get()); - effects.emplace_back(MemoryEffects::Write::get(), PayloadIRResource::get()); + consumesHandle(getTarget(), effects); + if (getDynamicSplitPoint()) + onlyReadsHandle(getDynamicSplitPoint(), effects); + producesHandle(getResults(), effects); + modifiesPayload(effects); } ParseResult SplitOp::parse(OpAsmParser &parser, OperationState &result) { diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp @@ -361,10 +361,11 @@ /// Returns `true` if the given list of effects instances contains an instance /// with the effect type specified as template parameter. -template +template static bool hasEffect(ArrayRef effects) { return llvm::any_of(effects, [](const MemoryEffects::EffectInstance &effect) { - return isa(effect.getEffect()); + return isa(effect.getEffect()) && + isa(effect.getResource()); }); } @@ -373,8 +374,8 @@ auto iface = cast(transform.getOperation()); SmallVector effects; iface.getEffectsOnValue(handle, effects); - return hasEffect(effects) && - hasEffect(effects); + return hasEffect(effects) && + hasEffect(effects); } void transform::producesHandle( diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -318,16 +318,8 @@ void transform::MergeHandlesOp::getEffects( SmallVectorImpl &effects) { - for (Value operand : getHandles()) { - effects.emplace_back(MemoryEffects::Read::get(), operand, - transform::TransformMappingResource::get()); - effects.emplace_back(MemoryEffects::Free::get(), operand, - transform::TransformMappingResource::get()); - } - effects.emplace_back(MemoryEffects::Allocate::get(), getResult(), - transform::TransformMappingResource::get()); - effects.emplace_back(MemoryEffects::Write::get(), getResult(), - transform::TransformMappingResource::get()); + consumesHandle(getHandles(), effects); + producesHandle(getResult(), effects); // There are no effects on the Payload IR as this is only a handle // manipulation. @@ -421,16 +413,11 @@ /// the Transform IR. That is, if it may have a Free effect on it. static bool isValueUsePotentialConsumer(OpOperand &use) { // Conservatively assume the effect being present in absence of the interface. - auto memEffectInterface = dyn_cast(use.getOwner()); - if (!memEffectInterface) + auto iface = dyn_cast(use.getOwner()); + if (!iface) return true; - SmallVector effects; - memEffectInterface.getEffectsOnValue(use.get(), effects); - return llvm::any_of(effects, [](const MemoryEffects::EffectInstance &effect) { - return isa(effect.getResource()) && - isa(effect.getEffect()); - }); + return isHandleConsumed(use.get(), iface); } LogicalResult diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td @@ -48,7 +48,7 @@ [DeclareOpInterfaceMethods]> { let arguments = (ins Arg:$operand, + [TransformMappingRead, TransformMappingFree]>:$operand, I64Attr:$parameter); let assemblyFormat = "$operand `[` $parameter `]` attr-dict"; let cppNamespace = "::mlir::test";