diff --git a/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td b/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td
--- a/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td
+++ b/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td
@@ -14,6 +14,38 @@
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/IR/OpBase.td"
 
+def EliminateBarriersOp :
+  Op<Transform_Dialect, "apply_patterns.gpu.eliminate_barriers",
+     [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Removes unnecessary GPU barriers from the function. If a barrier does not
+    enforce any conflicting pair of memory effects, including a pair that is
+    enforced by another barrier, it is unnecessary and can be removed.
+
+    The approach is based on "High-Performance GPU-to-CPU Transpilation and
+    Optimization via High-Level Parallel Constructs" by Moses, Ivanov, Domke,
+    Endo, Doerfert, and Zinenko in PPoPP 2023. Specifically, it analyzes the
+    memory effects of the operations before and after the given barrier and
+    checks if the barrier enforces any of the memory effect-induced
+    dependencies that aren't already enforced by another barrier.
+
+    For example, in the following code
+
+    ```mlir
+    store %A
+    barrier  // enforces load-after-store
+    load %A
+    barrier  // load-after-store already enforced by the previous barrier
+    load %A
+    ```
+
+    the second barrier can be removed.
+  }];
+
+  let assemblyFormat = [{ attr-dict }];
+}
+
 def MapNestedForallToThreads :
diff --git a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
--- a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
+++ b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
+#define DEBUG_TYPE_ALIAS "gpu-transforms-alias"
+#define DBGS_ALIAS() (llvm::dbgs() << "[" DEBUG_TYPE_ALIAS "] ")
+
+// The functions below provide interface-like verification, but are too
+// specific to barrier elimination to become interfaces.
+
+/// Returns `true` if the op is known to have no memory effects even though it
+/// does not implement the memory effect interface.
+static bool isKnownNoEffectsOpWithoutInterface(Operation *op) {
+  // memref.assume_alignment is conceptually pure, but marking it as such
+  // would make DCE immediately remove it.
+  return isa<memref::AssumeAlignmentOp>(op);
+}
+
+/// Returns `true` if the op defines the parallel region that is subject to
+/// barrier synchronization.
+static bool isParallelRegionBoundary(Operation *op) {
+  if (op->hasAttr("__parallel_region_boundary_for_test"))
+    return true;
+
+  return isa<GPUFuncOp, LaunchOp>(op);
+}
+
+/// Returns `true` if the op behaves like a sequential loop, e.g., the control
+/// flow "wraps around" from the end of the body region back to its start.
+static bool isSequentialLoopLike(Operation *op) { return isa<scf::ForOp>(op); }
+
+/// Returns `true` if the regions of the op are guaranteed to be executed at
+/// most once. Thus, if an operation in one of the nested regions of `op` is
+/// executed, then so are all the other operations in this region.
+static bool hasSingleExecutionBody(Operation *op) {
+  return isa<scf::IfOp, memref::AllocaScopeOp>(op);
+}
+
+/// Returns `true` if the operation is known to produce a pointer-like object
+/// distinct from any other object produced by a similar operation. For
+/// example, an allocation produces such an object.
+static bool producesDistinctBase(Operation *op) {
+  return isa_and_nonnull<memref::AllocOp, memref::AllocaOp>(op);
+}
+
+/// Populates `effects` with all memory effects without associating them to a
+/// specific value.
+static void addAllValuelessEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  effects.emplace_back(MemoryEffects::Effect::get<MemoryEffects::Read>());
+  effects.emplace_back(MemoryEffects::Effect::get<MemoryEffects::Write>());
+  effects.emplace_back(MemoryEffects::Effect::get<MemoryEffects::Allocate>());
+  effects.emplace_back(MemoryEffects::Effect::get<MemoryEffects::Free>());
+}
+
+/// Collect the memory effects of the given op in 'effects'. Returns 'true' if
+/// it could extract the effect information from the op, otherwise returns
+/// 'false' and conservatively populates the list with all possible effects
+/// associated with no particular value or symbol.
+static bool
+collectEffects(Operation *op,
+               SmallVectorImpl<MemoryEffects::EffectInstance> &effects,
+               bool ignoreBarriers = true) {
+  // Skip over barriers to avoid infinite recursion (those barriers would ask
+  // this barrier again).
+  if (ignoreBarriers && isa<BarrierOp>(op))
+    return true;
+
+  // Skip over ops that we know have no effects.
+  if (isKnownNoEffectsOpWithoutInterface(op))
+    return true;
+
+  // Collect effect instances of the operation. Note that the implementation
+  // of getEffects erases all effect instances that have a type other than the
+  // template parameter, so we collect them first in a local buffer and then
+  // copy.
+  if (auto iface = dyn_cast<MemoryEffectOpInterface>(op)) {
+    SmallVector<MemoryEffects::EffectInstance> localEffects;
+    iface.getEffects(localEffects);
+    llvm::append_range(effects, localEffects);
+    return true;
+  }
+  if (op->hasTrait<OpTrait::HasRecursiveMemoryEffects>()) {
+    for (auto &region : op->getRegions()) {
+      for (auto &block : region) {
+        for (auto &innerOp : block)
+          if (!collectEffects(&innerOp, effects, ignoreBarriers))
+            return false;
+      }
+    }
+    return true;
+  }
+
+  // We need to be conservative here in case the op doesn't have the interface
+  // and assume it can have any possible effect.
+  addAllValuelessEffects(effects);
+  return false;
+}
+
+/// Collects memory effects from operations that may be executed before `op`
+/// in a trivial structured control flow, e.g., without branches. Stops at the
+/// parallel region boundary or at the barrier operation if `stopAtBarrier` is
+/// set. Returns `true` if the memory effects added to `effects` are exact,
+/// `false` if they are a conservative over-approximation. The latter means
+/// that `effects` contain instances not associated with a specific value.
+bool getEffectsBefore(Operation *op,
+                      SmallVectorImpl<MemoryEffects::EffectInstance> &effects,
+                      bool stopAtBarrier) {
+  if (!op->getBlock())
+    return true;
+
+  // If there is non-structured control flow, bail.
+  Region *region = op->getBlock()->getParent();
+  if (region && !llvm::hasSingleElement(region->getBlocks())) {
+    addAllValuelessEffects(effects);
+    return false;
+  }
+
+  // Collect all effects before the op.
+  if (op != &op->getBlock()->front()) {
+    for (Operation *it = op->getPrevNode(); it != nullptr;
+         it = it->getPrevNode()) {
+      if (isa<BarrierOp>(it)) {
+        if (stopAtBarrier)
+          return true;
+        else
+          continue;
+      }
+      if (!collectEffects(it, effects))
+        return false;
+    }
+  }
+
+  // Stop if we reached the parallel region boundary.
+  if (isParallelRegionBoundary(op->getParentOp()))
+    return true;
+
+  // Otherwise, keep collecting above the parent operation.
+  if (!getEffectsBefore(op->getParentOp(), effects, stopAtBarrier))
+    return false;
+
+  // If the op is loop-like, collect effects from the trailing operations until
+  // we hit a barrier because they can be executed before the current operation
+  // by the previous iteration of this loop. For example, in the following loop
+  //
+  //   for i = ... {
+  //     op1
+  //     ...
+  //     barrier
+  //     op2
+  //   }
+  //
+  // the operation `op2` at iteration `i` is known to be executed before the
+  // operation `op1` at iteration `i+1` and the side effects must be ordered
+  // appropriately.
+  if (isSequentialLoopLike(op->getParentOp())) {
+    // Assuming loop terminators have no side effects.
+    return getEffectsBefore(op->getBlock()->getTerminator(), effects,
+                            /*stopAtBarrier=*/true);
+  }
+
+  // If the parent operation is not guaranteed to execute its (single-block)
+  // region once, walk the block.
+  bool conservative = false;
+  if (!hasSingleExecutionBody(op->getParentOp()))
+    op->getParentOp()->walk([&](Operation *in) {
+      if (conservative)
+        return WalkResult::interrupt();
+      if (!collectEffects(in, effects)) {
+        conservative = true;
+        return WalkResult::interrupt();
+      }
+      return WalkResult::advance();
+    });
+
+  return !conservative;
+}
+
+/// Collects memory effects from operations that may be executed after `op` in
+/// a trivial structured control flow, e.g., without branches. Stops at the
+/// parallel region boundary or at the barrier operation if `stopAtBarrier` is
+/// set. Returns `true` if the memory effects added to `effects` are exact,
+/// `false` if they are a conservative over-approximation. The latter means
+/// that `effects` contain instances not associated with a specific value.
+bool getEffectsAfter(Operation *op,
+                     SmallVectorImpl<MemoryEffects::EffectInstance> &effects,
+                     bool stopAtBarrier) {
+  if (!op->getBlock())
+    return true;
+
+  // If there is non-structured control flow, bail.
+  Region *region = op->getBlock()->getParent();
+  if (region && !llvm::hasSingleElement(region->getBlocks())) {
+    addAllValuelessEffects(effects);
+    return false;
+  }
+
+  // Collect all effects after the op.
+  if (op != &op->getBlock()->back())
+    for (Operation *it = op->getNextNode(); it != nullptr;
+         it = it->getNextNode()) {
+      if (isa<BarrierOp>(it)) {
+        if (stopAtBarrier)
+          return true;
+        continue;
+      }
+      if (!collectEffects(it, effects))
+        return false;
+    }
+
+  // Stop if we reached the parallel region boundary.
+  if (isParallelRegionBoundary(op->getParentOp()))
+    return true;
+
+  // Otherwise, keep collecting below the parent operation.
+  if (!getEffectsAfter(op->getParentOp(), effects, stopAtBarrier))
+    return false;
+
+  // If the op is loop-like, collect effects from the leading operations until
+  // we hit a barrier because they can be executed after the current operation
+  // by the next iteration of this loop. For example, in the following loop
+  //
+  //   for i = ... {
+  //     op1
+  //     ...
+  //     barrier
+  //     op2
+  //   }
+  //
+  // the operation `op1` at iteration `i` is known to be executed after the
+  // operation `op2` at iteration `i-1` and the side effects must be ordered
+  // appropriately.
+  if (isSequentialLoopLike(op->getParentOp())) {
+    if (isa<BarrierOp>(op->getBlock()->front()))
+      return true;
+
+    bool exact = collectEffects(&op->getBlock()->front(), effects);
+    return getEffectsAfter(&op->getBlock()->front(), effects,
+                           /*stopAtBarrier=*/true) &&
+           exact;
+  }
+
+  // If the parent operation is not guaranteed to execute its (single-block)
+  // region once, walk the block.
+  bool conservative = false;
+  if (!hasSingleExecutionBody(op->getParentOp()))
+    op->getParentOp()->walk([&](Operation *in) {
+      if (conservative)
+        return WalkResult::interrupt();
+      if (!collectEffects(in, effects)) {
+        conservative = true;
+        return WalkResult::interrupt();
+      }
+      return WalkResult::advance();
+    });
+
+  return !conservative;
+}
+
+/// Looks through known "view-like" ops to find the base memref.
+static Value getBase(Value v) {
+  while (true) {
+    Operation *definingOp = v.getDefiningOp();
+    if (!definingOp)
+      break;
+
+    bool shouldContinue =
+        TypeSwitch<Operation *, bool>(v.getDefiningOp())
+            .Case<memref::CastOp, memref::SubViewOp, memref::ViewOp>(
+                [&](auto op) {
+                  v = op.getSource();
+                  return true;
+                })
+            .Case<memref::TransposeOp>([&](auto op) {
+              v = op.getIn();
+              return true;
+            })
+            .Case<memref::CollapseShapeOp, memref::ExpandShapeOp>(
+                [&](auto op) {
+                  v = op.getSrc();
+                  return true;
+                })
+            .Default([](Operation *) { return false; });
+    if (!shouldContinue)
+      break;
+  }
+  return v;
+}
+
+/// Returns `true` if the value is defined as a function argument.
+static bool isFunctionArgument(Value v) {
+  auto arg = dyn_cast<BlockArgument>(v);
+  return arg && isa<FunctionOpInterface>(arg.getOwner()->getParentOp());
+}
+
+/// Returns the operand that the operation "propagates" through it for capture
+/// purposes. That is, if the value produced by this operation is captured,
+/// then so is the returned value.
+static Value propagatesCapture(Operation *op) {
+  return llvm::TypeSwitch<Operation *, Value>(op)
+      .Case(
+          [](ViewLikeOpInterface viewLike) { return viewLike.getViewSource(); })
+      .Case([](CastOpInterface castLike) { return castLike->getOperand(0); })
+      .Case([](memref::TransposeOp transpose) { return transpose.getIn(); })
+      .Case<memref::CollapseShapeOp, memref::ExpandShapeOp>(
+          [](auto op) { return op.getSrc(); })
+      .Default([](Operation *) { return Value(); });
+}
+
+/// Returns `true` if the given operation is known to capture the given value,
+/// `false` if it is known not to capture the given value, `nullopt` if
+/// neither is known.
+static std::optional<bool> getKnownCapturingStatus(Operation *op, Value v) {
+  return llvm::TypeSwitch<Operation *, std::optional<bool>>(op)
+      // Store-like operations don't capture the destination, but do capture
+      // the value.
+      .Case<memref::StoreOp, vector::TransferWriteOp>(
+          [&](auto op) { return op.getValue() == v; })
+      .Case<vector::StoreOp, vector::MaskedStoreOp>(
+          [&](auto op) { return op.getValueToStore() == v; })
+      // These operations are known not to capture.
+      .Case([](memref::DeallocOp) { return false; })
+      // By default, we don't know anything.
+      .Default([](Operation *) { return std::nullopt; });
+}
+
+/// Returns `true` if the value may be captured by any of its users, i.e., if
+/// the user may be storing this value into memory. This makes aliasing
+/// analysis more conservative as it cannot assume the pointer-like value is
+/// only passed around through SSA use-def.
+bool maybeCaptured(Value v) {
+  SmallVector<Value> todo = {v};
+  while (!todo.empty()) {
+    Value v = todo.pop_back_val();
+    for (Operation *user : v.getUsers()) {
+      // A user that is known to only read cannot capture.
+      auto iface = dyn_cast<MemoryEffectOpInterface>(user);
+      if (iface) {
+        SmallVector<MemoryEffects::EffectInstance> effects;
+        iface.getEffects(effects);
+        if (llvm::all_of(effects,
+                         [](const MemoryEffects::EffectInstance &effect) {
+                           return isa<MemoryEffects::Read>(effect.getEffect());
+                         })) {
+          continue;
+        }
+      }
+
+      // When an operation is known to create an alias, consider if the
+      // source is captured as well.
+      if (Value v = propagatesCapture(user)) {
+        todo.push_back(v);
+        continue;
+      }
+
+      std::optional<bool> knownCaptureStatus = getKnownCapturingStatus(user, v);
+      if (!knownCaptureStatus || *knownCaptureStatus)
+        return true;
+    }
+  }
+
+  return false;
+}
+
+/// Returns true if two values may be referencing aliasing memory. This is a
+/// rather naive and conservative analysis. Values defined by different
+/// allocation-like operations as well as values derived from those by casts
+/// and views cannot alias each other. Similarly, values defined by allocations
+/// inside a function cannot alias function arguments. Global values cannot
+/// alias each other or local allocations. Values that are captured, i.e.
+/// themselves potentially stored in memory, are considered as aliasing with
+/// everything. This seems sufficient to achieve barrier removal in structured
+/// control flow; more complex cases would require a proper dataflow analysis.
+static bool mayAlias(Value first, Value second) {
+  DEBUG_WITH_TYPE(DEBUG_TYPE_ALIAS, {
+    DBGS_ALIAS() << "checking aliasing between ";
+    DBGS_ALIAS() << first << "\n";
+    DBGS_ALIAS() << "                      and ";
+    DBGS_ALIAS() << second << "\n";
+  });
+
+  first = getBase(first);
+  second = getBase(second);
+
+  DEBUG_WITH_TYPE(DEBUG_TYPE_ALIAS, {
+    DBGS_ALIAS() << "base ";
+    DBGS_ALIAS() << first << "\n";
+    DBGS_ALIAS() << " and ";
+    DBGS_ALIAS() << second << "\n";
+  });
+
+  // Values derived from the same base memref do alias (unless we do a more
+  // advanced analysis to prove non-overlapping accesses).
+  if (first == second) {
+    DEBUG_WITH_TYPE(DEBUG_TYPE_ALIAS, DBGS_ALIAS() << "-> do alias!\n");
+    return true;
+  }
+
+  // Different globals cannot alias.
+  if (auto globFirst = first.getDefiningOp<memref::GetGlobalOp>()) {
+    if (auto globSecond = second.getDefiningOp<memref::GetGlobalOp>()) {
+      return globFirst.getNameAttr() == globSecond.getNameAttr();
+    }
+  }
+
+  // Two function arguments marked as noalias do not alias.
+  auto isNoaliasFuncArgument = [](Value value) {
+    auto bbArg = dyn_cast<BlockArgument>(value);
+    if (!bbArg)
+      return false;
+    auto iface = dyn_cast<FunctionOpInterface>(bbArg.getOwner()->getParentOp());
+    if (!iface)
+      return false;
+    // TODO: we need a way to not depend on the LLVM dialect here.
+    return iface.getArgAttr(bbArg.getArgNumber(), "llvm.noalias") != nullptr;
+  };
+  if (isNoaliasFuncArgument(first) && isNoaliasFuncArgument(second))
+    return false;
+
+  bool isDistinct[] = {producesDistinctBase(first.getDefiningOp()),
+                       producesDistinctBase(second.getDefiningOp())};
+  bool isGlobal[] = {first.getDefiningOp<memref::GetGlobalOp>() != nullptr,
+                     second.getDefiningOp<memref::GetGlobalOp>() != nullptr};
+
+  // Non-equivalent distinct bases and globals cannot alias. At this point, we
+  // have already filtered out based on values being equal and global name
+  // being equal.
+  if ((isDistinct[0] || isGlobal[0]) && (isDistinct[1] || isGlobal[1]))
+    return false;
+
+  bool isArg[] = {isFunctionArgument(first), isFunctionArgument(second)};
+
+  // Distinct bases (allocations) cannot have been passed as an argument.
+  if ((isDistinct[0] && isArg[1]) || (isDistinct[1] && isArg[0]))
+    return false;
+
+  // Non-captured base distinct values cannot conflict with another base value.
+  if (isDistinct[0] && !maybeCaptured(first))
+    return false;
+  if (isDistinct[1] && !maybeCaptured(second))
+    return false;
+
+  // Otherwise, conservatively assume aliasing.
+  DEBUG_WITH_TYPE(DEBUG_TYPE_ALIAS, DBGS_ALIAS() << "-> may alias!\n");
+  return true;
+}
+
+/// Returns `true` if the effect may be affecting memory aliasing the value. If
+/// the effect is not associated with any value, it is assumed to affect all
+/// memory and therefore aliases with everything.
+bool mayAlias(MemoryEffects::EffectInstance a, Value v2) {
+  if (Value v = a.getValue()) {
+    return mayAlias(v, v2);
+  }
+  return true;
+}
+
+/// Returns `true` if the two effects may be affecting aliasing memory. If
+/// an effect is not associated with any value, it is assumed to affect all
+/// memory and therefore aliases with everything. Effects on different
+/// resources cannot alias.
+bool mayAlias(MemoryEffects::EffectInstance a,
+              MemoryEffects::EffectInstance b) {
+  if (a.getResource()->getResourceID() != b.getResource()->getResourceID())
+    return false;
+  if (Value v2 = b.getValue()) {
+    return mayAlias(a, v2);
+  } else if (Value v = a.getValue()) {
+    return mayAlias(b, v);
+  }
+  return true;
+}
+
+/// Returns `true` if any of the "before" effect instances has a conflict with
+/// any "after" instance for the purpose of barrier elimination.
+/// The effects are supposed to be limited to a barrier synchronization scope.
+/// A conflict exists if effect instances affect aliasing memory locations and
+/// at least one of them is a write. As an exception, if the non-write effect
+/// is an allocation effect, there is no conflict since we are only expected to
+/// see the allocation happening in the same thread and it cannot be accessed
+/// from another thread without capture (which we do handle in alias analysis).
+static bool
+haveConflictingEffects(ArrayRef<MemoryEffects::EffectInstance> beforeEffects,
+                       ArrayRef<MemoryEffects::EffectInstance> afterEffects) {
+  for (const MemoryEffects::EffectInstance &before : beforeEffects) {
+    for (const MemoryEffects::EffectInstance &after : afterEffects) {
+      // If the effects cannot alias, there is definitely no conflict.
+      if (!mayAlias(before, after))
+        continue;
+
+      // Read/read is not a conflict.
+      if (isa<MemoryEffects::Read>(before.getEffect()) &&
+          isa<MemoryEffects::Read>(after.getEffect())) {
+        continue;
+      }
+
+      // Allocate/* is not a conflict since the allocation happens within the
+      // thread context.
+      // TODO: This is not the case for */Free unless the allocation happened
+      // in the thread context, which we could also check for.
+      if (isa<MemoryEffects::Allocate>(before.getEffect()) ||
+          isa<MemoryEffects::Allocate>(after.getEffect())) {
+        continue;
+      }
+
+      // In the particular case that the before effect is a free, we only have
+      // 2 possibilities:
+      //   1. either the program is well-formed and there must be an
+      //      interleaved alloc that must limit the scope of effect lookback
+      //      and we can safely ignore the free -> read / free -> write and
+      //      free -> free conflicts;
+      //   2. or the program is ill-formed and we are in undefined behavior
+      //      territory.
+      if (isa<MemoryEffects::Free>(before.getEffect()))
+        continue;
+
+      // Other kinds of effects create a conflict, e.g. read-after-write.
+      LLVM_DEBUG(
+          DBGS() << "found a conflict between (before): " << before.getValue()
+                 << " read:" << isa<MemoryEffects::Read>(before.getEffect())
+                 << " write:" << isa<MemoryEffects::Write>(before.getEffect())
+                 << " alloc:"
+                 << isa<MemoryEffects::Allocate>(before.getEffect())
+                 << " free:" << isa<MemoryEffects::Free>(before.getEffect())
+                 << "\n");
+      LLVM_DEBUG(
+          DBGS() << "and (after): " << after.getValue()
+                 << " read:" << isa<MemoryEffects::Read>(after.getEffect())
+                 << " write:" << isa<MemoryEffects::Write>(after.getEffect())
+                 << " alloc:" << isa<MemoryEffects::Allocate>(after.getEffect())
+                 << " free:" << isa<MemoryEffects::Free>(after.getEffect())
+                 << "\n");
+      return true;
+    }
+  }
+
+  return false;
+}
+
+namespace {
+/// Barrier elimination pattern. If a barrier does not enforce any conflicting
+/// pair of memory effects, including a pair that is enforced by another
+/// barrier, it is unnecessary and can be removed. Adapted from
+/// "High-Performance GPU-to-CPU Transpilation and Optimization via High-Level
+/// Parallel Constructs" by Moses, Ivanov, Domke, Endo, Doerfert, and Zinenko
+/// in PPoPP 2023 and the implementation in Polygeist.
+class BarrierElimination final : public OpRewritePattern<BarrierOp> {
+public:
+  using OpRewritePattern<BarrierOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(BarrierOp barrier,
+                                PatternRewriter &rewriter) const override {
+    LLVM_DEBUG(DBGS() << "checking the necessity of: " << barrier << " "
+                      << barrier.getLoc() << "\n");
+
+    SmallVector<MemoryEffects::EffectInstance> beforeEffects;
+    getEffectsBefore(barrier, beforeEffects, /*stopAtBarrier=*/true);
+
+    SmallVector<MemoryEffects::EffectInstance> afterEffects;
+    getEffectsAfter(barrier, afterEffects, /*stopAtBarrier=*/true);
+
+    if (!haveConflictingEffects(beforeEffects, afterEffects)) {
+      LLVM_DEBUG(DBGS() << "the surrounding barriers are sufficient, removing "
+                        << barrier << "\n");
+      rewriter.eraseOp(barrier);
+      return success();
+    }
+
+    LLVM_DEBUG(DBGS() << "barrier is necessary: " << barrier << " "
+                      << barrier.getLoc() << "\n");
+    return failure();
+  }
+};
+} // namespace
+
+void EliminateBarriersOp::populatePatterns(RewritePatternSet &patterns) {
+  patterns.insert<BarrierElimination>(getContext());
+}
+
+//===----------------------------------------------------------------------===//
+// Block and thread mapping utilities.
+//===----------------------------------------------------------------------===//
 
 namespace {
diff --git a/mlir/test/Dialect/GPU/barrier-elimination.mlir b/mlir/test/Dialect/GPU/barrier-elimination.mlir
new file
--- /dev/null
+++ b/mlir/test/Dialect/GPU/barrier-elimination.mlir
@@ -0,0 +1,181 @@
+// RUN: mlir-opt %s --test-transform-dialect-interpreter | FileCheck %s
+
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op):
+  %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+  transform.apply_patterns to %0 {
+    transform.apply_patterns.gpu.eliminate_barriers
+  } : !transform.any_op
+}
+
+// CHECK-LABEL: @read_read_write
+func.func @read_read_write(%arg0: memref<?xf32>, %arg1: index) attributes {__parallel_region_boundary_for_test} {
+  // CHECK: load
+  %0 = memref.load %arg0[%arg1] : memref<?xf32>
+  // The barrier between loads can be removed.
+  // CHECK-NOT: barrier
+  gpu.barrier
+  // CHECK: load
+  %1 = memref.load %arg0[%arg1] : memref<?xf32>
+  %2 = arith.addf %0, %1 : f32
+  // The barrier between load and store cannot be removed (unless we reason about accessed subsets).
+  // CHECK: barrier
+  gpu.barrier
+  // CHECK: store
+  memref.store %2, %arg0[%arg1] : memref<?xf32>
+  return
+}
+
+// CHECK-LABEL: @write_read_read
+func.func @write_read_read(%arg0: memref<?xf32>, %arg1: index, %arg2: f32) -> f32
+attributes {__parallel_region_boundary_for_test} {
+  // CHECK: store
+  memref.store %arg2, %arg0[%arg1] : memref<?xf32>
+  // The barrier between store and load cannot be removed (unless we reason about accessed subsets).
+  // CHECK: barrier
+  gpu.barrier
+  // CHECK: load
+  %0 = memref.load %arg0[%arg1] : memref<?xf32>
+  // CHECK-NOT: barrier
+  gpu.barrier
+  // CHECK: load
+  %1 = memref.load %arg0[%arg1] : memref<?xf32>
+  %2 = arith.addf %0, %1 : f32
+  return %2 : f32
+}
+
+// CHECK-LABEL: @write_in_a_loop
+func.func @write_in_a_loop(%arg0: memref<?xf32>, %arg1: f32) attributes {__parallel_region_boundary_for_test} {
+  %c0 = arith.constant 0 : index
+  %c42 = arith.constant 42 : index
+  %c1 = arith.constant 1 : index
+  scf.for %i = %c0 to %c42 step %c1 {
+    memref.store %arg1, %arg0[%i] : memref<?xf32>
+    // Cannot remove this barrier because it guards write-after-write between different iterations.
+    // CHECK: barrier
+    gpu.barrier
+  }
+  return
+}
+
+// CHECK-LABEL: @read_read_write_loop
+func.func @read_read_write_loop(%arg0: memref<?xf32>, %arg1: f32) attributes {__parallel_region_boundary_for_test} {
+  %c0 = arith.constant 0 : index
+  %c42 = arith.constant 42 : index
+  %c1 = arith.constant 1 : index
+  scf.for %i = %c0 to %c42 step %c1 {
+    // (Note that if the subscript were different, this would have been a race with the store at the end of the loop).
+    %0 = memref.load %arg0[%i] : memref<?xf32>
+    // Guards read-after-write where the write happens on the previous iteration.
+    // CHECK: barrier
+    gpu.barrier
+    %1 = memref.load %arg0[%i] : memref<?xf32>
+    %2 = arith.addf %0, %1 : f32
+    // Guards write-after-read.
+    // CHECK: barrier
+    gpu.barrier
+    memref.store %2, %arg0[%i] : memref<?xf32>
+  }
+  return
+}
+
+// CHECK-LABEL: @read_read_write_loop_trailing_sync
+func.func @read_read_write_loop_trailing_sync(%arg0: memref<?xf32>, %arg1: f32) attributes {__parallel_region_boundary_for_test} {
+  %c0 = arith.constant 0 : index
+  %c42 = arith.constant 42 : index
+  %c1 = arith.constant 1 : index
+  scf.for %i = %c0 to %c42 step %c1 {
+    // CHECK: load
+    %0 = memref.load %arg0[%i] : memref<?xf32>
+    // This can be removed because it only guards a read-after-read.
+    // CHECK-NOT: barrier
+    gpu.barrier
+    // CHECK: load
+    %1 = memref.load %arg0[%i] : memref<?xf32>
+    %2 = arith.addf %0, %1 : f32
+    // CHECK: barrier
+    gpu.barrier
+    // CHECK: store
+    memref.store %2, %arg0[%i] : memref<?xf32>
+    // CHECK: barrier
+    gpu.barrier
+  }
+  return
+}
+
+// CHECK-LABEL: @write_write_noalias
+func.func @write_write_noalias(%arg0: index, %arg1: f32) -> (memref<42xf32>, memref<10xf32>)
+attributes {__parallel_region_boundary_for_test} {
+  %0 = memref.alloc() : memref<42xf32>
+  %1 = memref.alloc() : memref<10xf32>
+  // CHECK: store
+  memref.store %arg1, %0[%arg0] : memref<42xf32>
+  // This can be removed because we can prove the two allocations don't alias.
+  // CHECK-NOT: barrier
+  gpu.barrier
+  // CHECK: store
+  memref.store %arg1, %1[%arg0] : memref<10xf32>
+  return %0, %1 : memref<42xf32>, memref<10xf32>
+}
+
+// CHECK-LABEL: @write_write_alloc_arg_noalias
+func.func @write_write_alloc_arg_noalias(%arg0: index, %arg1: f32, %arg2: memref<?xf32>) -> (memref<42xf32>)
+attributes {__parallel_region_boundary_for_test} {
+  %0 = memref.alloc() : memref<42xf32>
+  // CHECK: store
+  memref.store %arg1, %0[%arg0] : memref<42xf32>
+  // This can be removed because we can prove the local allocation doesn't alias with a function argument.
+  // CHECK-NOT: barrier
+  gpu.barrier
+  // CHECK: store
+  memref.store %arg1, %arg2[%arg0] : memref<?xf32>
+  return %0 : memref<42xf32>
+}
+
+// CHECK-LABEL: @repeated_barrier
+func.func @repeated_barrier(%arg0: memref<?xf32>, %arg1: index, %arg2: f32) -> f32
+attributes {__parallel_region_boundary_for_test} {
+  %0 = memref.load %arg0[%arg1] : memref<?xf32>
+  // CHECK: gpu.barrier
+  gpu.barrier
+  // CHECK-NOT: gpu.barrier
+  gpu.barrier
+  memref.store %arg2, %arg0[%arg1] : memref<?xf32>
+  return %0 : f32
+}
+
+// CHECK-LABEL: @symmetric_stop
+func.func @symmetric_stop(%val: f32) -> (f32, f32, f32, f32, f32)
+attributes {__parallel_region_boundary_for_test} {
+  // CHECK: %[[A:.+]] = memref.alloc
+  // CHECK: %[[B:.+]] = memref.alloc
+  // CHECK: %[[C:.+]] = memref.alloc
+  %A = memref.alloc() : memref<f32>
+  %B = memref.alloc() : memref<f32>
+  %C = memref.alloc() : memref<f32>
+  // CHECK: memref.store %{{.*}}, %[[A]]
+  memref.store %val, %A[] : memref<f32>
+  // CHECK: gpu.barrier
+  gpu.barrier
+  // CHECK: memref.load %[[A]]
+  %0 = memref.load %A[] : memref<f32>
+  // CHECK: memref.store %{{.*}}, %[[B]]
+  memref.store %val, %B[] : memref<f32>
+  // This barrier is eliminated because the surrounding barriers are sufficient
+  // to guard write/read on all memrefs.
+  // CHECK-NOT: gpu.barrier
+  gpu.barrier
+  // CHECK: memref.load %[[A]]
+  %1 = memref.load %A[] : memref<f32>
+  // CHECK: memref.store %{{.*}}, %[[C]]
+  memref.store %val, %C[] : memref<f32>
+  // CHECK: gpu.barrier
+  gpu.barrier
+  // CHECK: memref.load %[[A]]
+  // CHECK: memref.load %[[B]]
+  // CHECK: memref.load %[[C]]
+  %2 = memref.load %A[] : memref<f32>
+  %3 = memref.load %B[] : memref<f32>
+  %4 = memref.load %C[] : memref<f32>
+  return %0, %1, %2, %3, %4 : f32, f32, f32, f32, f32
+}
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -4689,12 +4689,14 @@
         ":GPUTransformOpsIncGen",
         ":GPUTransforms",
         ":IR",
+        ":MemRefDialect",
         ":Parser",
         ":SCFDialect",
         ":SideEffectInterfaces",
         ":Support",
         ":TransformDialect",
        ":TransformUtils",
+        ":VectorDialect",
         "//llvm:Support",
     ],
 )
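
Editor's note, not part of the patch: the sketch below illustrates the intended end-to-end flow on a small, hypothetical function, using the same harness as the test above (the `--test-transform-dialect-interpreter` pass and the `__parallel_region_boundary_for_test` attribute). The function name and operands are illustrative only.

```mlir
// Before the transform: the second barrier only re-enforces the
// load-after-store ordering that the first barrier already guarantees,
// so the pattern is expected to erase it and keep the first one.
func.func @illustration(%m: memref<?xf32>, %i: index, %v: f32)
    attributes {__parallel_region_boundary_for_test} {
  memref.store %v, %m[%i] : memref<?xf32>
  gpu.barrier   // kept: guards load-after-store
  %0 = memref.load %m[%i] : memref<?xf32>
  gpu.barrier   // removed: only guards load-after-load
  %1 = memref.load %m[%i] : memref<?xf32>
  return
}

// Transform script applying the new pattern, mirroring the test file.
transform.sequence failures(propagate) {
^bb0(%arg0: !transform.any_op):
  %f = transform.structured.match ops{["func.func"]} in %arg0
      : (!transform.any_op) -> !transform.any_op
  transform.apply_patterns to %f {
    transform.apply_patterns.gpu.eliminate_barriers
  } : !transform.any_op
}
```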