diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
--- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
@@ -44,6 +44,14 @@
 /// Extract int64_t values from the assumed ArrayAttr of IntegerAttr.
 SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr);
 
+/// Given a value, try to extract a constant Attribute. If this fails, return
+/// the original value.
+OpFoldResult getAsOpFoldResult(Value val);
+
+/// Given an array of values, try to extract a constant Attribute from each
+/// value. If this fails, return the original value.
+SmallVector<OpFoldResult> getAsOpFoldResult(ArrayRef<Value> values);
+
 /// If ofr is a constant integer or an IntegerAttr, return the integer.
 Optional<int64_t> getConstantIntValue(OpFoldResult ofr);
 
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -89,20 +89,6 @@
 template <typename NamedStructuredOpType>
 static void printNamedStructuredOp(OpAsmPrinter &p, NamedStructuredOpType op);
 
-/// Helper function to convert a Value into an OpFoldResult, if the Value is
-/// known to be a constant index value.
-static SmallVector<OpFoldResult> getAsOpFoldResult(ArrayRef<Value> values) {
-  return llvm::to_vector<4>(
-      llvm::map_range(values, [](Value v) -> OpFoldResult {
-        APInt intValue;
-        if (v.getType().isa<IndexType>() &&
-            matchPattern(v, m_ConstantInt(&intValue))) {
-          return IntegerAttr::get(v.getType(), intValue.getSExtValue());
-        }
-        return v;
-      }));
-}
-
 /// Helper function to convert a vector of `OpFoldResult`s into a vector of
 /// `Value`s.
 static SmallVector<Value> getAsValues(OpBuilder &b, Location loc,
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/Dialect/Vector/VectorOps.h"
 #include "mlir/IR/AffineExpr.h"
@@ -798,14 +799,6 @@
   return builder.create<ConstantIndexOp>(loc, *intVal);
 }
 
-/// Given a value, try to extract a constant index-type integer as an Attribute.
-/// If this fails, return the original value.
-static OpFoldResult asOpFoldResult(OpBuilder &builder, Value val) {
-  if (auto constInt = getConstantIntValue(val))
-    return builder.getIndexAttr(*constInt);
-  return val;
-}
-
 LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
     tensor::ExtractSliceOp sliceOp, PatternRewriter &rewriter) const {
   auto padOp = sliceOp.source().getDefiningOp<PadTensorOp>();
@@ -895,7 +888,7 @@
     // ExtractSliceOp length will be zero in that case. (Effectively reading no
     // data from the source.)
     Value newOffset = min(max(sub(offset, low), zero), srcSize);
-    newOffsets.push_back(asOpFoldResult(rewriter, newOffset));
+    newOffsets.push_back(getAsOpFoldResult(newOffset));
 
     // The original ExtractSliceOp was reading until position `offset + length`.
     // Therefore, the corresponding position within the source tensor is:
@@ -915,7 +908,7 @@
     // The new ExtractSliceOp length is `endLoc - newOffset`.
     Value endLoc = min(max(add(sub(offset, low), length), zero), srcSize);
     Value newLength = sub(endLoc, newOffset);
-    newLengths.push_back(asOpFoldResult(rewriter, newLength));
+    newLengths.push_back(getAsOpFoldResult(newLength));
 
     // Check if newLength is zero. In that case, no SubTensorOp should be
     // executed.
diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
--- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
+++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
@@ -47,6 +47,22 @@
       }));
 }
 
+/// Given a value, try to extract a constant Attribute. If this fails, return
+/// the original value.
+OpFoldResult getAsOpFoldResult(Value val) {
+  Attribute attr;
+  if (matchPattern(val, m_Constant(&attr)))
+    return attr;
+  return val;
+}
+
+/// Given an array of values, try to extract a constant Attribute from each
+/// value. If this fails, return the original value.
+SmallVector<OpFoldResult> getAsOpFoldResult(ArrayRef<Value> values) {
+  return llvm::to_vector<4>(
+      llvm::map_range(values, [](Value v) { return getAsOpFoldResult(v); }));
+}
+
 /// If ofr is a constant integer or an IntegerAttr, return the integer.
 Optional<int64_t> getConstantIntValue(OpFoldResult ofr) {
   // Case 1: Check for Constant integer.