diff --git a/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp b/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp --- a/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp @@ -157,10 +157,12 @@ /// Returns an accumulator factory using either the op attribute or the body /// region. AccumulatorFactory getFactory() { - if (!reduceOp.body().empty()) - return getFactory(reduceOp.body()); - if (reduceOp.op()) - return getFactory(*reduceOp.op()); + auto &body = reduceOp.body(); + if (!body.empty()) + return getFactory(body); + auto opAttr = reduceOp.op(); + if (opAttr) + return getFactory(*opAttr); return AccumulatorFactory(); } @@ -186,7 +188,7 @@ // Replace all gpu.yield ops with branch out of body. for (; block != split; block = block->getNextNode()) { Operation *terminator = block->getTerminator(); - if (!llvm::isa(terminator)) + if (!isa(terminator)) continue; rewriter.setInsertionPointToEnd(block); rewriter.replaceOpWithNewOp( @@ -297,13 +299,13 @@ createIf( shuffleOp.getResult(1), [&] { - return llvm::SmallVector{ + return SmallVector{ accumFactory(value, shuffleOp.getResult(0))}; }, [&] { return llvm::makeArrayRef(value); }); value = rewriter.getInsertionBlock()->getArgument(0); } - return llvm::SmallVector{value}; + return SmallVector{value}; }, // Generate a reduction over the entire subgroup. This is a specialization // of the above reduction with unconditional accumulation. 
@@ -315,7 +317,7 @@ offset, subgroupSize, xorAttr); value = accumFactory(value, shuffleOp.getResult(0)); } - return llvm::SmallVector{value}; + return SmallVector{value}; }); return rewriter.getInsertionBlock()->getArgument(0); } @@ -344,7 +346,7 @@ PatternMatchResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override { - auto funcOp = llvm::cast(op); + auto funcOp = cast(op); auto callback = [&](gpu::AllReduceOp reduceOp) { GpuAllReduceRewriter(funcOp, reduceOp, rewriter).rewrite(); return WalkResult::interrupt();