diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -299,9 +299,9 @@ } LogicalResult ForOp::verify() { - if (auto cst = getStep().getDefiningOp()) - if (cst.value() <= 0) - return emitOpError("constant step operand must be positive"); + IntegerAttr step; + if (matchPattern(getStep(), m_Constant(&step)) && step.getInt() <= 0) + return emitOpError("constant step operand must be positive"); auto opNumResults = getNumResults(); if (opNumResults == 0) @@ -719,11 +719,10 @@ /// Returns llvm::None when the difference between two AffineValueMap is /// dynamic. static Optional computeConstDiff(Value l, Value u) { - auto clb = l.getDefiningOp(); - auto cub = u.getDefiningOp(); - if (cub && clb) { - llvm::APInt lbValue = clb.getValue().cast().getValue(); - llvm::APInt ubValue = cub.getValue().cast().getValue(); + IntegerAttr clb, cub; + if (matchPattern(l, m_Constant(&clb)) && matchPattern(u, m_Constant(&cub))) { + llvm::APInt lbValue = clb.getValue(); + llvm::APInt ubValue = cub.getValue(); return (ubValue - lbValue).getSExtValue(); } @@ -763,13 +762,13 @@ return success(); } - auto step = op.getStep().getDefiningOp(); - if (!step) + IntegerAttr step; + if (!matchPattern(op.getStep(), m_Constant(&step))) return failure(); // If the loop is known to have 1 iteration, inline its body and remove the // loop. - llvm::APInt stepValue = step.getValue().cast().getValue(); + llvm::APInt stepValue = step.getValue(); if (stepValue.sge(*diff)) { SmallVector blockArgs; blockArgs.reserve(op.getNumIterOperands() + 1); @@ -1674,11 +1673,11 @@ LogicalResult matchAndRewrite(IfOp op, PatternRewriter &rewriter) const override { - auto constant = op.getCondition().getDefiningOp(); - if (!constant) + BoolAttr condition; + if (!matchPattern(op.getCondition(), m_Constant(&condition))) return failure(); - if (constant.getValue().cast().getValue()) + if (condition.getValue()) replaceOpWithRegion(rewriter, op, op.getThenRegion()); else if (!op.getElseRegion().empty()) replaceOpWithRegion(rewriter, op, op.getElseRegion()); @@ -1777,7 +1776,7 @@ PatternRewriter &rewriter) const override { // Early exit if the condition is constant since replacing a constant // in the body with another constant isn't a simplification. - if (op.getCondition().getDefiningOp()) + if (matchPattern(op.getCondition(), m_Constant())) return failure(); bool changed = false; @@ -1881,25 +1880,23 @@ continue; } - auto trueYield = trueResult.getDefiningOp(); - if (!trueYield) + BoolAttr trueYield, falseYield; + if (!matchPattern(trueResult, m_Constant(&trueYield)) || + !matchPattern(falseResult, m_Constant(&falseYield))) continue; - if (!trueYield.getType().isInteger(1)) - continue; - - auto falseYield = falseResult.getDefiningOp(); - if (!falseYield) - continue; - - bool trueVal = trueYield.getValue().cast().getValue(); - bool falseVal = falseYield.getValue().cast().getValue(); + bool trueVal = trueYield.getValue(); + bool falseVal = falseYield.getValue(); if (!trueVal && falseVal) { if (!opResult.use_empty()) { + Dialect *constDialect = trueResult.getDefiningOp()->getDialect(); Value notCond = rewriter.create( op.getLoc(), op.getCondition(), - rewriter.create( - op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 1))); + constDialect + ->materializeConstant(rewriter, + rewriter.getIntegerAttr(i1Ty, 1), i1Ty, + op.getLoc()) + ->getResult(0)); opResult.replaceAllUsesWith(notCond); changed = true; }