diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
@@ -122,6 +122,16 @@
 bool applyCmpPredicate(CmpFPredicate predicate, const APFloat &lhs,
                        const APFloat &rhs);
 
+/// If ofr is a constant integer, i.e., an IntegerAttr or a ConstantOp with an
+/// IntegerAttr, return the integer. Return llvm::None otherwise, including
+/// when the constant does not fit in an int64_t.
+llvm::Optional<int64_t> getConstantIntValue(OpFoldResult ofr);
+
+/// Return true if ofr and value are the same integer.
+/// Ignore integer bitwidth and type mismatch that come from the fact there is
+/// no IndexAttr and that IndexType has no bitwidth.
+bool isEqualConstantInt(OpFoldResult ofr, int64_t value);
+
 /// Return true if ofr1 and ofr2 are the same integer constant attribute values
 /// or the same SSA value.
 /// Ignore integer bitwitdh and type mismatch that come from the fact there is
diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.h b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
--- a/mlir/include/mlir/Interfaces/ViewLikeInterface.h
+++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
@@ -30,6 +30,13 @@
 
 class OffsetSizeAndStrideOpInterface;
 
+/// Return true if ofr is a constant integer equal to value.
+/// Defined in StandardOps/IR/Ops.cpp (and declared in StandardOps/IR/Ops.h);
+/// re-declared here because the default implementations of hasUnitStride and
+/// hasZeroOffset in ViewLikeInterface.td call it. Ops relying on those
+/// defaults therefore need that definition at link time.
+bool isEqualConstantInt(OpFoldResult ofr, int64_t value);
+
 namespace detail {
 
 LogicalResult verifyOffsetSizeAndStrideOp(OffsetSizeAndStrideOpInterface op);
diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.td b/mlir/include/mlir/Interfaces/ViewLikeInterface.td
--- a/mlir/include/mlir/Interfaces/ViewLikeInterface.td
+++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.td
@@ -436,6 +436,30 @@
             $_op.getOperation()), other, cmp);
       }]
     >,
+    InterfaceMethod<
+      /*desc=*/[{ Return true if all strides are guaranteed to be 1. }],
+      /*retTy=*/"bool",
+      /*methodName=*/"hasUnitStride",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return ::llvm::all_of(getMixedStrides(), [](OpFoldResult ofr) {
+          return ::mlir::isEqualConstantInt(ofr, 1);
+        });
+      }]
+    >,
+    InterfaceMethod<
+      /*desc=*/[{ Return true if all offsets are guaranteed to be 0. }],
+      /*retTy=*/"bool",
+      /*methodName=*/"hasZeroOffset",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return ::llvm::all_of(getMixedOffsets(), [](OpFoldResult ofr) {
+          return ::mlir::isEqualConstantInt(ofr, 0);
+        });
+      }]
+    >,
   ];
 
   let extraClassDeclaration = [{
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -60,24 +60,39 @@
     dispatchIndexOpFoldResult(ofr, dynamicVec, staticVec, sentinel);
 }
 
+/// If ofr is a constant integer, i.e., an IntegerAttr or a ConstantOp with an
+/// IntegerAttr, return the integer. Return llvm::None otherwise, including
+/// when the constant does not fit in an int64_t.
+llvm::Optional<int64_t> mlir::getConstantIntValue(OpFoldResult ofr) {
+  Attribute attr = ofr.dyn_cast<Attribute>();
+  if (!attr) {
+    if (auto constantOp = ofr.get<Value>().getDefiningOp<ConstantOp>())
+      attr = constantOp.getValue();
+  }
+  auto intAttr = attr.dyn_cast_or_null<IntegerAttr>();
+  // APInt::getSExtValue() asserts when the value needs more than 64 bits.
+  if (!intAttr || intAttr.getValue().getMinSignedBits() > 64)
+    return llvm::None;
+  return intAttr.getValue().getSExtValue();
+}
+
+/// Return true if ofr and value are the same integer.
+/// Ignore integer bitwidth and type mismatch that come from the fact there is
+/// no IndexAttr and that IndexType has no bitwidth.
+bool mlir::isEqualConstantInt(OpFoldResult ofr, int64_t value) {
+  auto ofrValue = getConstantIntValue(ofr);
+  return ofrValue && *ofrValue == value;
+}
+
 /// Return true if ofr1 and ofr2 are the same integer constant attribute values
 /// or the same SSA value.
-/// Ignore integer bitwitdh and type mismatch that come from the fact there is
-/// no IndexAttr and that IndexType have no bitwidth.
-bool mlir::isEqualConstantIntOrValue(OpFoldResult op1, OpFoldResult op2) {
-  auto getConstantIntValue = [](OpFoldResult ofr) -> llvm::Optional<int64_t> {
-    Attribute attr = ofr.dyn_cast<Attribute>();
-    // Note: isa+cast-like pattern allows writing the condition below as 1 line.
-    if (!attr && ofr.get<Value>().getDefiningOp<ConstantOp>())
-      attr = ofr.get<Value>().getDefiningOp<ConstantOp>().getValue();
-    if (auto intAttr = attr.dyn_cast_or_null<IntegerAttr>())
-      return intAttr.getValue().getSExtValue();
-    return llvm::None;
-  };
-  auto cst1 = getConstantIntValue(op1), cst2 = getConstantIntValue(op2);
+/// Ignore integer bitwidth and type mismatch that come from the fact there is
+/// no IndexAttr and that IndexType has no bitwidth.
+bool mlir::isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2) {
+  auto cst1 = getConstantIntValue(ofr1), cst2 = getConstantIntValue(ofr2);
   if (cst1 && cst2 && *cst1 == *cst2)
     return true;
-  auto v1 = op1.dyn_cast<Value>(), v2 = op2.dyn_cast<Value>();
+  auto v1 = ofr1.dyn_cast<Value>(), v2 = ofr2.dyn_cast<Value>();
   return v1 && v2 && v1 == v2;
 }
 