diff --git a/mlir/include/mlir/Interfaces/SideEffectInterfaces.h b/mlir/include/mlir/Interfaces/SideEffectInterfaces.h --- a/mlir/include/mlir/Interfaces/SideEffectInterfaces.h +++ b/mlir/include/mlir/Interfaces/SideEffectInterfaces.h @@ -256,6 +256,12 @@ template bool hasSingleEffect(Operation *op, Value value = nullptr); +/// Returns true if `op` has an effect of type `EffectTy` on `value`. If no +/// `value` is provided, simply check if effects of the given type(s) are +/// present. +template +bool hasEffect(Operation *op, Value value = nullptr); + /// Return true if the given operation is unused, and has no side effects on /// memory that prevent erasing. bool isOpTriviallyDead(Operation *op); diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp --- a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp @@ -803,18 +803,11 @@ Node node(nextNodeId++, &op); nodes.insert({node.id, node}); } - } else if (auto effectInterface = dyn_cast(op)) { + } else if (hasEffect(&op)) { // Create graph node for top-level op, which could have a memory write // side effect. - SmallVector effects; - effectInterface.getEffects(effects); - if (llvm::any_of(effects, [](const MemoryEffects::EffectInstance &it) { - return isa( - it.getEffect()); - })) { - Node node(nextNodeId++, &op); - nodes.insert({node.id, node}); - } + Node node(nextNodeId++, &op); + nodes.insert({node.id, node}); } } diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefDialect.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefDialect.cpp --- a/mlir/lib/Dialect/MemRef/IR/MemRefDialect.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefDialect.cpp @@ -42,22 +42,11 @@ addInterfaces(); } -/// Finds a single dealloc operation for the given allocated value. +/// Finds the unique dealloc operation (if one exists) for `allocValue`. llvm::Optional mlir::memref::findDealloc(Value allocValue) { Operation *dealloc = nullptr; for (Operation *user : allocValue.getUsers()) { - auto effectInterface = dyn_cast(user); - if (!effectInterface) - continue; - // Try to find a free effect that is applied to one of our values - // that will be automatically freed by our pass. - SmallVector effects; - effectInterface.getEffectsOnValue(allocValue, effects); - const bool isFree = - llvm::any_of(effects, [&](MemoryEffects::EffectInstance &it) { - return isa(it.getEffect()); - }); - if (!isFree) + if (!hasEffect(user, allocValue)) continue; // If we found > 1 dealloc, return None. if (dealloc) diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp @@ -374,8 +374,8 @@ auto iface = cast(transform.getOperation()); SmallVector effects; iface.getEffectsOnValue(handle, effects); - return hasEffect(effects) && - hasEffect(effects); + return ::hasEffect(effects) && + ::hasEffect(effects); } void transform::producesHandle( diff --git a/mlir/lib/Interfaces/SideEffectInterfaces.cpp b/mlir/lib/Interfaces/SideEffectInterfaces.cpp --- a/mlir/lib/Interfaces/SideEffectInterfaces.cpp +++ b/mlir/lib/Interfaces/SideEffectInterfaces.cpp @@ -128,6 +128,26 @@ template bool mlir::hasSingleEffect(Operation *, Value); template bool mlir::hasSingleEffect(Operation *, Value); +template +bool mlir::hasEffect(Operation *op, Value value) { + auto memOp = dyn_cast(op); + if (!memOp) + return false; + SmallVector, 4> effects; + memOp.getEffects(effects); + return llvm::any_of(effects, [&](MemoryEffects::EffectInstance &effect) { + if (value && effect.getValue() != value) + return false; + return isa(effect.getEffect()); + }); +} +template bool mlir::hasEffect(Operation *, Value); +template bool mlir::hasEffect(Operation *, Value); +template bool mlir::hasEffect(Operation *, Value); +template bool mlir::hasEffect(Operation *, Value); +template bool +mlir::hasEffect(Operation *, Value); + bool mlir::wouldOpBeTriviallyDead(Operation *op) { if (op->mightHaveTrait()) return false;