diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
--- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
@@ -15,6 +15,7 @@
 #ifndef MLIR_DIALECT_UTILS_STATICVALUEUTILS_H
 #define MLIR_DIALECT_UTILS_STATICVALUEUTILS_H
 
+#include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/SmallVector.h"
@@ -57,8 +58,16 @@
                                 SmallVectorImpl<Value> &dynamicVec,
                                 SmallVectorImpl<int64_t> &staticVec);
 
-/// Extract int64_t values from the assumed ArrayAttr of IntegerAttr.
-SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr);
+/// Extract integer values from the assumed ArrayAttr of IntegerAttr.
+/// `IntTy` cannot be deduced and must be spelled explicitly; values are read
+/// via `IntegerAttr::getInt()` and converted to `IntTy` without range checks.
+template <typename IntTy>
+SmallVector<IntTy> extractFromIntegerArrayAttr(Attribute attr) {
+  return llvm::to_vector(
+      llvm::map_range(cast<ArrayAttr>(attr), [](Attribute a) -> IntTy {
+        return cast<IntegerAttr>(a).getInt();
+      }));
+}
 
 /// Given a value, try to extract a constant Attribute. If this fails, return
 /// the original value.
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -416,9 +416,10 @@
 transform::FuseOp::apply(transform::TransformRewriter &rewriter,
                          mlir::transform::TransformResults &transformResults,
                          mlir::transform::TransformState &state) {
-  SmallVector<int64_t> tileSizes = extractFromI64ArrayAttr(getTileSizes());
+  SmallVector<int64_t> tileSizes =
+      extractFromIntegerArrayAttr<int64_t>(getTileSizes());
   SmallVector<int64_t> tileInterchange =
-      extractFromI64ArrayAttr(getTileInterchange());
+      extractFromIntegerArrayAttr<int64_t>(getTileInterchange());
 
   scf::SCFTilingOptions tilingOptions;
   tilingOptions.interchangeVector = tileInterchange;
@@ -471,7 +472,7 @@
 
 LogicalResult transform::FuseOp::verify() {
   SmallVector<int64_t> permutation =
-      extractFromI64ArrayAttr(getTileInterchange());
+      extractFromIntegerArrayAttr<int64_t>(getTileInterchange());
   auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size()));
   if (!std::is_permutation(sequence.begin(), sequence.end(),
                            permutation.begin(), permutation.end())) {
@@ -479,7 +480,8 @@
                          << getTileInterchange();
   }
 
-  SmallVector<int64_t> sizes = extractFromI64ArrayAttr(getTileSizes());
+  SmallVector<int64_t> sizes =
+      extractFromIntegerArrayAttr<int64_t>(getTileSizes());
   size_t numExpectedLoops = sizes.size() - llvm::count(sizes, 0);
   if (numExpectedLoops != getNumResults() - 1)
     return emitOpError() << "expects " << numExpectedLoops << " loop results";
@@ -1571,7 +1573,8 @@
 
   // Convert the integer packing flags to booleans.
   SmallVector<bool> packPaddings;
-  for (int64_t packPadding : extractFromI64ArrayAttr(getPackPaddings()))
+  for (int64_t packPadding :
+       extractFromIntegerArrayAttr<int64_t>(getPackPaddings()))
     packPaddings.push_back(static_cast<bool>(packPadding));
 
   // Convert the padding values to attributes.
@@ -1611,15 +1614,17 @@
   // Extract the transpose vectors.
   SmallVector<SmallVector<int64_t>> transposePaddings;
   for (Attribute transposeVector : cast<ArrayAttr>(getTransposePaddings()))
-    transposePaddings.push_back(
-        extractFromI64ArrayAttr(cast<ArrayAttr>(transposeVector)));
+    transposePaddings.push_back(extractFromIntegerArrayAttr<int64_t>(
+        cast<ArrayAttr>(transposeVector)));
 
   LinalgOp paddedOp;
   LinalgPaddingOptions options;
-  options.paddingDimensions = extractFromI64ArrayAttr(getPaddingDimensions());
+  options.paddingDimensions =
+      extractFromIntegerArrayAttr<int64_t>(getPaddingDimensions());
   SmallVector<int64_t> padToMultipleOf(options.paddingDimensions.size(), 1);
   if (getPadToMultipleOf().has_value())
-    padToMultipleOf = extractFromI64ArrayAttr(*getPadToMultipleOf());
+    padToMultipleOf =
+        extractFromIntegerArrayAttr<int64_t>(*getPadToMultipleOf());
   options.padToMultipleOf = padToMultipleOf;
   options.paddingValues = paddingValues;
   options.packPaddings = packPaddings;
@@ -1650,7 +1655,7 @@
 
 LogicalResult transform::PadOp::verify() {
   SmallVector<int64_t> packPaddings =
-      extractFromI64ArrayAttr(getPackPaddings());
+      extractFromIntegerArrayAttr<int64_t>(getPackPaddings());
   if (any_of(packPaddings, [](int64_t packPadding) {
         return packPadding != 0 && packPadding != 1;
       })) {
@@ -1660,7 +1665,7 @@
   }
 
   SmallVector<int64_t> paddingDimensions =
-      extractFromI64ArrayAttr(getPaddingDimensions());
+      extractFromIntegerArrayAttr<int64_t>(getPaddingDimensions());
   if (any_of(paddingDimensions,
              [](int64_t paddingDimension) { return paddingDimension < 0; })) {
     return emitOpError() << "expects padding_dimensions to contain positive "
@@ -1674,7 +1679,7 @@
   }
   ArrayAttr transposes = getTransposePaddings();
   for (Attribute attr : transposes) {
-    SmallVector<int64_t> transpose = extractFromI64ArrayAttr(attr);
+    SmallVector<int64_t> transpose = extractFromIntegerArrayAttr<int64_t>(attr);
     auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
     if (!std::is_permutation(sequence.begin(), sequence.end(),
                              transpose.begin(), transpose.end())) {
@@ -1791,7 +1796,7 @@
   LinalgPromotionOptions promotionOptions;
   if (!getOperandsToPromote().empty())
     promotionOptions = promotionOptions.setOperandsToPromote(
-        extractFromI64ArrayAttr(getOperandsToPromote()));
+        extractFromIntegerArrayAttr<int64_t>(getOperandsToPromote()));
   if (getUseFullTilesByDefault())
     promotionOptions = promotionOptions.setUseFullTileBuffersByDefault(
         getUseFullTilesByDefault());
diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
--- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
+++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
@@ -68,14 +68,6 @@
     dispatchIndexOpFoldResult(ofr, dynamicVec, staticVec);
 }
 
-/// Extract int64_t values from the assumed ArrayAttr of IntegerAttr.
-SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr) {
-  return llvm::to_vector<4>(
-      llvm::map_range(cast<ArrayAttr>(attr), [](Attribute a) -> int64_t {
-        return cast<IntegerAttr>(a).getInt();
-      }));
-}
-
 /// Given a value, try to extract a constant Attribute. If this fails, return
 /// the original value.
 OpFoldResult getAsOpFoldResult(Value val) {
diff --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
--- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
@@ -165,7 +165,7 @@
       if (!dyn_cast<ArrayAttr>(attribute.getValue()))
         return failure();
       SmallVector<int64_t> values =
-          extractFromI64ArrayAttr(attribute.getValue());
+          extractFromIntegerArrayAttr<int64_t>(attribute.getValue());
       generateMetadata(values[0], NVVM::NVVMDialect::getMaxntidXName());
       if (values.size() > 1)
         generateMetadata(values[1], NVVM::NVVMDialect::getMaxntidYName());
@@ -175,7 +175,7 @@
       if (!dyn_cast<ArrayAttr>(attribute.getValue()))
         return failure();
       SmallVector<int64_t> values =
-          extractFromI64ArrayAttr(attribute.getValue());
+          extractFromIntegerArrayAttr<int64_t>(attribute.getValue());
       generateMetadata(values[0], NVVM::NVVMDialect::getReqntidXName());
       if (values.size() > 1)
         generateMetadata(values[1], NVVM::NVVMDialect::getReqntidYName());