diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp @@ -10,8 +10,8 @@ // //===----------------------------------------------------------------------===// -#include #include +#include #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" @@ -20,6 +20,8 @@ #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVectorExtras.h" using namespace mlir; @@ -82,14 +84,14 @@ /// Combines chained `spirv::AccessChainOp` operations into one /// `spirv::AccessChainOp` operation. -struct CombineChainedAccessChain - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; +struct CombineChainedAccessChain final + : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(spirv::AccessChainOp accessChainOp, PatternRewriter &rewriter) const override { - auto parentAccessChainOp = dyn_cast_or_null( - accessChainOp.getBasePtr().getDefiningOp()); + auto parentAccessChainOp = + accessChainOp.getBasePtr().getDefiningOp(); if (!parentAccessChainOp) { return failure(); @@ -97,8 +99,7 @@ // Combine indices. SmallVector indices(parentAccessChainOp.getIndices()); - indices.append(accessChainOp.getIndices().begin(), - accessChainOp.getIndices().end()); + llvm::append_range(indices, accessChainOp.getIndices()); rewriter.replaceOpWithNewOp( accessChainOp, parentAccessChainOp.getBasePtr(), indices); @@ -155,17 +156,16 @@ auto type = llvm::cast(constructOp.getType()); if (getIndices().size() == 1 && constructOp.getConstituents().size() == type.getNumElements()) { - auto i = getIndices().begin()->cast(); - if (static_cast(i.getValue().getSExtValue()) < - constructOp.getConstituents().size()) + auto i = llvm::cast(*getIndices().begin()); + if (i.getValue().getSExtValue() < + static_cast(constructOp.getConstituents().size())) return constructOp.getConstituents()[i.getValue().getSExtValue()]; } } - auto indexVector = - llvm::to_vector<8>(llvm::map_range(getIndices(), [](Attribute attr) { - return static_cast(llvm::cast(attr).getInt()); - })); + auto indexVector = llvm::map_to_vector(getIndices(), [](Attribute attr) { + return static_cast(llvm::cast(attr).getInt()); + }); return extractCompositeElement(adaptor.getComposite(), indexVector); } @@ -289,13 +289,15 @@ OpFoldResult spirv::LogicalOrOp::fold(FoldAdaptor adaptor) { if (auto rhs = getScalarOrSplatBoolAttr(adaptor.getOperand2())) { - if (*rhs) + if (*rhs) { // x || true = true return adaptor.getOperand2(); + } - // x || false = x - if (!*rhs) + if (!*rhs) { + // x || false = x return getOperand1(); + } } return Attribute(); @@ -331,14 +333,13 @@ // | merge block | // +-------------+ // -struct ConvertSelectionOpToSelect - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; +struct ConvertSelectionOpToSelect final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(spirv::SelectionOp selectionOp, PatternRewriter &rewriter) const override { - auto *op = selectionOp.getOperation(); - auto &body = op->getRegion(0); + Operation *op = selectionOp.getOperation(); + Region &body = op->getRegion(0); // Verifier allows an empty region for `spirv.mlir.selection`. if (body.empty()) { return failure(); @@ -346,11 +347,11 @@ // Check that region consists of 4 blocks: // header block, `true` block, `false` block and merge block. - if (std::distance(body.begin(), body.end()) != 4) { + if (llvm::range_size(body) != 4) { return failure(); } - auto *headerBlock = selectionOp.getHeaderBlock(); + Block *headerBlock = selectionOp.getHeaderBlock(); if (!onlyContainsBranchConditionalOp(headerBlock)) { return failure(); } @@ -358,16 +359,16 @@ auto brConditionalOp = cast(headerBlock->front()); - auto *trueBlock = brConditionalOp.getSuccessor(0); - auto *falseBlock = brConditionalOp.getSuccessor(1); - auto *mergeBlock = selectionOp.getMergeBlock(); + Block *trueBlock = brConditionalOp.getSuccessor(0); + Block *falseBlock = brConditionalOp.getSuccessor(1); + Block *mergeBlock = selectionOp.getMergeBlock(); if (failed(canCanonicalizeSelection(trueBlock, falseBlock, mergeBlock))) return failure(); - auto trueValue = getSrcValue(trueBlock); - auto falseValue = getSrcValue(falseBlock); - auto ptrValue = getDstPtr(trueBlock); + Value trueValue = getSrcValue(trueBlock); + Value falseValue = getSrcValue(falseBlock); + Value ptrValue = getDstPtr(trueBlock); auto storeOpAttributes = cast(trueBlock->front())->getAttrs(); @@ -393,7 +394,7 @@ Block *mergeBlock) const; bool onlyContainsBranchConditionalOp(Block *block) const { - return std::next(block->begin()) == block->end() && + return llvm::hasSingleElement(*block) && isa(block->front()); } @@ -419,8 +420,7 @@ LogicalResult ConvertSelectionOpToSelect::canCanonicalizeSelection( Block *trueBlock, Block *falseBlock, Block *mergeBlock) const { // Each block must consists of 2 operations. - if ((std::distance(trueBlock->begin(), trueBlock->end()) != 2) || - (std::distance(falseBlock->begin(), falseBlock->end()) != 2)) { + if (llvm::range_size(*trueBlock) != 2 || llvm::range_size(*falseBlock) != 2) { return failure(); }