diff --git a/mlir/include/mlir/Dialect/Arith/Utils/Utils.h b/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
--- a/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
@@ -26,12 +26,6 @@
 /// Matches a ConstantIndexOp.
 detail::op_matcher<arith::ConstantIndexOp> matchConstantIndex();
 
-/// Detects the `values` produced by a ConstantIndexOp and places the new
-/// constant in place of the corresponding sentinel value.
-/// TODO(pifon2a): Remove this function and use foldDynamicIndexList.
-void canonicalizeSubViewPart(SmallVectorImpl<OpFoldResult> &values,
-                             function_ref<bool(int64_t)> isDynamic);
-
 /// Returns `success` when any of the elements in `ofrs` was produced by
 /// arith::ConstantIndexOp. In that case the constant attribute replaces the
 /// Value. Returns `failure` when no folding happened.
@@ -50,20 +44,20 @@
 
   LogicalResult matchAndRewrite(OpType op,
                                 PatternRewriter &rewriter) const override {
-    // No constant operand, just return;
-    if (llvm::none_of(op.getOperands(), [](Value operand) {
-          return matchPattern(operand, matchConstantIndex());
-        }))
-      return failure();
-
-    // At least one of offsets/sizes/strides is a new constant.
-    // Form the new list of operands and constant attributes from the existing.
     SmallVector<OpFoldResult> mixedOffsets(op.getMixedOffsets());
     SmallVector<OpFoldResult> mixedSizes(op.getMixedSizes());
     SmallVector<OpFoldResult> mixedStrides(op.getMixedStrides());
-    canonicalizeSubViewPart(mixedOffsets, ShapedType::isDynamic);
-    canonicalizeSubViewPart(mixedSizes, ShapedType::isDynamic);
-    canonicalizeSubViewPart(mixedStrides, ShapedType::isDynamic);
+
+    // Fold all three lists before testing the results. Chaining the calls
+    // directly with `&&` would short-circuit as soon as one list folds and
+    // leave the remaining lists unfolded in the op created below.
+    LogicalResult offsetsFolded = foldDynamicIndexList(rewriter, mixedOffsets);
+    LogicalResult sizesFolded = foldDynamicIndexList(rewriter, mixedSizes);
+    LogicalResult stridesFolded = foldDynamicIndexList(rewriter, mixedStrides);
+
+    // No constant operands were folded, just return.
+    if (failed(offsetsFolded) && failed(sizesFolded) && failed(stridesFolded))
+      return failure();
 
     // Create the new op in canonical form.
     ResultTypeFunc resultTypeFunc;
diff --git a/mlir/lib/Dialect/Arith/Utils/Utils.cpp b/mlir/lib/Dialect/Arith/Utils/Utils.cpp
--- a/mlir/lib/Dialect/Arith/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Arith/Utils/Utils.cpp
@@ -23,21 +23,6 @@
   return detail::op_matcher<arith::ConstantIndexOp>();
 }
 
-// Detects the `values` produced by a ConstantIndexOp and places the new
-// constant in place of the corresponding sentinel value.
-void mlir::canonicalizeSubViewPart(
-    SmallVectorImpl<OpFoldResult> &values,
-    llvm::function_ref<bool(int64_t)> isDynamic) {
-  for (OpFoldResult &ofr : values) {
-    if (ofr.is<Attribute>())
-      continue;
-    // Newly static, move from Value to constant.
-    if (auto cstOp =
-            ofr.dyn_cast<Value>().getDefiningOp<arith::ConstantIndexOp>())
-      ofr = OpBuilder(cstOp).getIndexAttr(cstOp.value());
-  }
-}
-
 // Returns `success` when any of the elements in `ofrs` was produced by
 // arith::ConstantIndexOp. In that case the constant attribute replaces the
 // Value. Returns `failure` when no folding happened.
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -2227,21 +2227,20 @@
 
   LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
                                 PatternRewriter &rewriter) const override {
-    // No constant operand, just return.
-    if (llvm::none_of(insertSliceOp.getOperands(), [](Value operand) {
-          return matchPattern(operand, matchConstantIndex());
-        }))
-      return failure();
-
-    // At least one of offsets/sizes/strides is a new constant.
-    // Form the new list of operands and constant attributes from the
-    // existing.
     SmallVector<OpFoldResult> mixedOffsets(insertSliceOp.getMixedOffsets());
     SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
     SmallVector<OpFoldResult> mixedStrides(insertSliceOp.getMixedStrides());
-    canonicalizeSubViewPart(mixedOffsets, ShapedType::isDynamic);
-    canonicalizeSubViewPart(mixedSizes, ShapedType::isDynamic);
-    canonicalizeSubViewPart(mixedStrides, ShapedType::isDynamic);
+
+    // Fold all three lists before testing the results. Chaining the calls
+    // directly with `&&` would short-circuit as soon as one list folds and
+    // leave the remaining lists unfolded in the op created below.
+    LogicalResult offsetsFolded = foldDynamicIndexList(rewriter, mixedOffsets);
+    LogicalResult sizesFolded = foldDynamicIndexList(rewriter, mixedSizes);
+    LogicalResult stridesFolded = foldDynamicIndexList(rewriter, mixedStrides);
+
+    // No constant operands were folded, just return.
+    if (failed(offsetsFolded) && failed(sizesFolded) && failed(stridesFolded))
+      return failure();
 
     // Create the new op in canonical form.
     auto sourceType = ExtractSliceOp::inferCanonicalRankReducedResultType(