diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -949,11 +949,9 @@ let assemblyFormat = [{ $target oilist( `num_threads` custom($num_threads, - $static_num_threads, - "ShapedType::kDynamic") | + $static_num_threads) | `tile_sizes` custom($tile_sizes, - $static_tile_sizes, - "ShapedType::kDynamic")) + $static_tile_sizes)) (`(` `mapping` `=` $mapping^ `)`)? attr-dict }]; let hasVerifier = 1; diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td --- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td +++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td @@ -1267,14 +1267,11 @@ let assemblyFormat = [{ $source `to` `offset` `` `:` - custom($offsets, $static_offsets, - "ShapedType::kDynamic") + custom($offsets, $static_offsets) `` `,` `sizes` `` `:` - custom($sizes, $static_sizes, - "ShapedType::kDynamic") + custom($sizes, $static_sizes) `` `,` `strides` `` `:` - custom($strides, $static_strides, - "ShapedType::kDynamic") + custom($strides, $static_strides) attr-dict `:` type($source) `to` type($result) }]; @@ -1865,12 +1862,9 @@ let assemblyFormat = [{ $source `` - custom($offsets, $static_offsets, - "ShapedType::kDynamic") - custom($sizes, $static_sizes, - "ShapedType::kDynamic") - custom($strides, $static_strides, - "ShapedType::kDynamic") + custom($offsets, $static_offsets) + custom($sizes, $static_sizes) + custom($strides, $static_strides) attr-dict `:` type($source) `to` type($result) }]; diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td --- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td +++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td @@ -334,12 +334,9 @@ let assemblyFormat = [{ $source 
`` - custom($offsets, $static_offsets, - "ShapedType::kDynamic") - custom($sizes, $static_sizes, - "ShapedType::kDynamic") - custom($strides, $static_strides, - "ShapedType::kDynamic") + custom($offsets, $static_offsets) + custom($sizes, $static_sizes) + custom($strides, $static_strides) attr-dict `:` type($source) `to` type($result) }]; @@ -818,12 +815,9 @@ let assemblyFormat = [{ $source `into` $dest `` - custom($offsets, $static_offsets, - "ShapedType::kDynamic") - custom($sizes, $static_sizes, - "ShapedType::kDynamic") - custom($strides, $static_strides, - "ShapedType::kDynamic") + custom($offsets, $static_offsets) + custom($sizes, $static_sizes) + custom($strides, $static_strides) attr-dict `:` type($source) `into` type($dest) }]; @@ -1221,10 +1215,8 @@ let assemblyFormat = [{ $source (`nofold` $nofold^)? - `low` `` custom($low, $static_low, - "ShapedType::kDynamic") - `high` `` custom($high, $static_high, - "ShapedType::kDynamic") + `low` `` custom($low, $static_low) + `high` `` custom($high, $static_high) $region attr-dict `:` type($source) `to` type($result) }]; @@ -1411,12 +1403,9 @@ ); let assemblyFormat = [{ $source `into` $dest `` - custom($offsets, $static_offsets, - "ShapedType::kDynamic") - custom($sizes, $static_sizes, - "ShapedType::kDynamic") - custom($strides, $static_strides, - "ShapedType::kDynamic") + custom($offsets, $static_offsets) + custom($sizes, $static_sizes) + custom($strides, $static_strides) attr-dict `:` type($source) `into` type($dest) }]; diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.h b/mlir/include/mlir/Interfaces/ViewLikeInterface.h --- a/mlir/include/mlir/Interfaces/ViewLikeInterface.h +++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.h @@ -21,39 +21,17 @@ namespace mlir { -/// Return a vector of OpFoldResults given the special value -/// that indicates whether of the value is dynamic or not. 
+/// Return a vector of OpFoldResults with the same size as staticValues, where +/// each element for which ShapedType::isDynamic is true is replaced by the +/// next entry of dynamicValues, consumed in order. SmallVector getMixedValues(ArrayAttr staticValues, - ValueRange dynamicValues, - int64_t dynamicValueIndicator); - -/// Return a vector of all the static and dynamic offsets/strides. -SmallVector getMixedStridesOrOffsets(ArrayAttr staticValues, - ValueRange dynamicValues); - -/// Return a vector of all the static and dynamic sizes. -SmallVector getMixedSizes(ArrayAttr staticValues, - ValueRange dynamicValues); + ValueRange dynamicValues); /// Decompose a vector of mixed static or dynamic values into the corresponding /// pair of arrays. This is the inverse function of `getMixedValues`. std::pair> decomposeMixedValues(Builder &b, - const SmallVectorImpl &mixedValues, - const int64_t dynamicValueIndicator); - -/// Decompose a vector of mixed static and dynamic strides/offsets into the -/// corresponding pair of arrays. This is the inverse function of -/// `getMixedStridesOrOffsets`. -std::pair> decomposeMixedStridesOrOffsets( - OpBuilder &b, const SmallVectorImpl &mixedValues); - -/// Decompose a vector of mixed static or dynamic strides/offsets into the -/// corresponding pair of arrays. This is the inverse function of -/// `getMixedSizes`. -std::pair> -decomposeMixedSizes(OpBuilder &b, - const SmallVectorImpl &mixedValues); + const SmallVectorImpl &mixedValues); class OffsetSizeAndStrideOpInterface; @@ -83,8 +61,7 @@ /// idiomatic printing of mixed value and integer attributes in a list. E.g. /// `[%arg0, 7, 42, %arg42]`. void printDynamicIndexList(OpAsmPrinter &printer, Operation *op, - OperandRange values, ArrayAttr integers, - int64_t dynVal); + OperandRange values, ArrayAttr integers); -/// Pasrer hook for custom directive in assemblyFormat. +/// Parser hook for custom directive in assemblyFormat.
/// @@ -102,13 +79,13 @@ ParseResult parseDynamicIndexList(OpAsmParser &parser, SmallVectorImpl &values, - ArrayAttr &integers, int64_t dynVal); + ArrayAttr &integers); -/// Verify that a the `values` has as many elements as the number of entries in -/// `attr` for which `isDynamic` evaluates to true. -LogicalResult verifyListOfOperandsOrIntegers( - Operation *op, StringRef name, unsigned expectedNumElements, ArrayAttr attr, - ValueRange values, function_ref isDynamic); +/// Verify that `attr` has `expectedNumElements` entries and that `values` has +/// one element per `attr` entry for which `ShapedType::isDynamic` is true. +LogicalResult verifyListOfOperandsOrIntegers(Operation *op, StringRef name, + unsigned expectedNumElements, + ArrayAttr attr, ValueRange values); } // namespace mlir diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.td b/mlir/include/mlir/Interfaces/ViewLikeInterface.td --- a/mlir/include/mlir/Interfaces/ViewLikeInterface.td +++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.td @@ -165,8 +165,8 @@ /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - return ::mlir::getMixedStridesOrOffsets($_op.getStaticOffsets(), - $_op.getOffsets()); + return ::mlir::getMixedValues($_op.getStaticOffsets(), + $_op.getOffsets()); }] >, InterfaceMethod< @@ -178,7 +178,7 @@ /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - return ::mlir::getMixedSizes($_op.getStaticSizes(), $_op.sizes()); + return ::mlir::getMixedValues($_op.getStaticSizes(), $_op.sizes()); }] >, InterfaceMethod< @@ -190,15 +190,13 @@ /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - return ::mlir::getMixedStridesOrOffsets($_op.getStaticStrides(), - $_op.getStrides()); + return ::mlir::getMixedValues($_op.getStaticStrides(), + $_op.getStrides()); }] >, InterfaceMethod< - /*desc=*/[{ - Return true if the offset `idx` is dynamic. - }], + /*desc=*/"Return true if the offset `idx` is dynamic.", /*retTy=*/"bool", /*methodName=*/"isDynamicOffset", /*args=*/(ins "unsigned":$idx), @@ -210,9 +208,7 @@ }] >, InterfaceMethod< - /*desc=*/[{ - Return true if the size `idx` is dynamic.
- }], + /*desc=*/"Return true if the size `idx` is dynamic.", /*retTy=*/"bool", /*methodName=*/"isDynamicSize", /*args=*/(ins "unsigned":$idx), @@ -224,9 +220,7 @@ }] >, InterfaceMethod< - /*desc=*/[{ - Return true if the stride `idx` is dynamic. - }], + /*desc=*/"Return true if the stride `idx` is dynamic.", /*retTy=*/"bool", /*methodName=*/"isDynamicStride", /*args=*/(ins "unsigned":$idx), diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -1321,8 +1321,7 @@ auto pdlOperationType = pdl::OperationType::get(parser.getContext()); if (parser.parseOperand(target) || parser.resolveOperand(target, pdlOperationType, result.operands) || - parseDynamicIndexList(parser, dynamicSizes, staticSizes, - ShapedType::kDynamic) || + parseDynamicIndexList(parser, dynamicSizes, staticSizes) || parser.resolveOperands(dynamicSizes, pdlOperationType, result.operands) || parser.parseOptionalAttrDict(result.attributes)) return ParseResult::failure(); @@ -1336,8 +1335,7 @@ void TileOp::print(OpAsmPrinter &p) { p << ' ' << getTarget(); - printDynamicIndexList(p, getOperation(), getDynamicSizes(), getStaticSizes(), - ShapedType::kDynamic); + printDynamicIndexList(p, getOperation(), getDynamicSizes(), getStaticSizes()); p.printOptionalAttrDict((*this)->getAttrs(), {getStaticSizesAttrName()}); } @@ -1549,11 +1547,11 @@ } SmallVector TileToForeachThreadOp::getMixedNumThreads() { - return getMixedSizes(getStaticNumThreads(), getNumThreads()); + return getMixedValues(getStaticNumThreads(), getNumThreads()); } SmallVector TileToForeachThreadOp::getMixedTileSizes() { - return getMixedSizes(getStaticTileSizes(), getTileSizes()); + return getMixedValues(getStaticTileSizes(), getTileSizes()); } LogicalResult TileToForeachThreadOp::verify() { @@ -1680,8 +1678,7 @@ auto pdlOperationType = 
pdl::OperationType::get(parser.getContext()); if (parser.parseOperand(target) || parser.resolveOperand(target, pdlOperationType, result.operands) || - parseDynamicIndexList(parser, dynamicSizes, staticSizes, - ShapedType::kDynamic) || + parseDynamicIndexList(parser, dynamicSizes, staticSizes) || parser.resolveOperands(dynamicSizes, pdlOperationType, result.operands) || parser.parseOptionalAttrDict(result.attributes)) return ParseResult::failure(); @@ -1695,8 +1692,7 @@ void TileToScfForOp::print(OpAsmPrinter &p) { p << ' ' << getTarget(); - printDynamicIndexList(p, getOperation(), getDynamicSizes(), getStaticSizes(), - ShapedType::kDynamic); + printDynamicIndexList(p, getOperation(), getDynamicSizes(), getStaticSizes()); p.printOptionalAttrDict((*this)->getAttrs(), {getStaticSizesAttrName()}); } diff --git a/mlir/lib/Interfaces/ViewLikeInterface.cpp b/mlir/lib/Interfaces/ViewLikeInterface.cpp --- a/mlir/lib/Interfaces/ViewLikeInterface.cpp +++ b/mlir/lib/Interfaces/ViewLikeInterface.cpp @@ -17,16 +17,18 @@ /// Include the definitions of the loop-like interfaces. #include "mlir/Interfaces/ViewLikeInterface.cpp.inc" -LogicalResult mlir::verifyListOfOperandsOrIntegers( - Operation *op, StringRef name, unsigned numElements, ArrayAttr attr, - ValueRange values, llvm::function_ref isDynamic) { +LogicalResult mlir::verifyListOfOperandsOrIntegers(Operation *op, + StringRef name, + unsigned numElements, + ArrayAttr attr, + ValueRange values) { /// Check static and dynamic offsets/sizes/strides does not overflow type. 
if (attr.size() != numElements) return op->emitError("expected ") << numElements << " " << name << " values"; unsigned expectedNumDynamicEntries = llvm::count_if(attr.getValue(), [&](Attribute attr) { - return isDynamic(attr.cast().getInt()); + return ShapedType::isDynamic(attr.cast().getInt()); }); if (values.size() != expectedNumDynamicEntries) return op->emitError("expected ") @@ -56,23 +58,19 @@ << ") so the rank of the result type is well-formed."; if (failed(verifyListOfOperandsOrIntegers(op, "offset", maxRanks[0], - op.static_offsets(), op.offsets(), - ShapedType::isDynamic))) + op.static_offsets(), op.offsets()))) return failure(); if (failed(verifyListOfOperandsOrIntegers(op, "size", maxRanks[1], - op.static_sizes(), op.sizes(), - ShapedType::isDynamic))) + op.static_sizes(), op.sizes()))) return failure(); if (failed(verifyListOfOperandsOrIntegers(op, "stride", maxRanks[2], - op.static_strides(), op.strides(), - ShapedType::isDynamic))) + op.static_strides(), op.strides()))) return failure(); return success(); } void mlir::printDynamicIndexList(OpAsmPrinter &printer, Operation *op, - OperandRange values, ArrayAttr integers, - int64_t dynVal) { + OperandRange values, ArrayAttr integers) { printer << '['; if (integers.empty()) { printer << "]"; @@ -81,7 +79,7 @@ unsigned idx = 0; llvm::interleaveComma(integers, printer, [&](Attribute a) { int64_t val = a.cast().getInt(); - if (val == dynVal) + if (ShapedType::isDynamic(val)) printer << values[idx++]; else printer << val; @@ -92,7 +90,7 @@ ParseResult mlir::parseDynamicIndexList( OpAsmParser &parser, SmallVectorImpl &values, - ArrayAttr &integers, int64_t dynVal) { + ArrayAttr &integers) { if (failed(parser.parseLSquare())) return failure(); // 0-D. 
@@ -107,7 +105,7 @@ auto res = parser.parseOptionalOperand(operand); if (res.has_value() && succeeded(res.value())) { values.push_back(operand); - attrVals.push_back(dynVal); + attrVals.push_back(ShapedType::kDynamic); } else { IntegerAttr attr; if (failed(parser.parseAttribute(attr))) @@ -147,57 +145,33 @@ return true; } -SmallVector -mlir::getMixedValues(ArrayAttr staticValues, ValueRange dynamicValues, - const int64_t dynamicValueIndicator) { +SmallVector mlir::getMixedValues(ArrayAttr staticValues, + ValueRange dynamicValues) { SmallVector res; res.reserve(staticValues.size()); unsigned numDynamic = 0; unsigned count = static_cast(staticValues.size()); for (unsigned idx = 0; idx < count; ++idx) { APInt value = staticValues[idx].cast().getValue(); - res.push_back(value.getSExtValue() == dynamicValueIndicator + res.push_back(ShapedType::isDynamic(value.getSExtValue()) ? OpFoldResult{dynamicValues[numDynamic++]} : OpFoldResult{staticValues[idx]}); } return res; } -SmallVector -mlir::getMixedStridesOrOffsets(ArrayAttr staticValues, - ValueRange dynamicValues) { - return getMixedValues(staticValues, dynamicValues, ShapedType::kDynamic); -} - -SmallVector mlir::getMixedSizes(ArrayAttr staticValues, - ValueRange dynamicValues) { - return getMixedValues(staticValues, dynamicValues, ShapedType::kDynamic); -} - std::pair> mlir::decomposeMixedValues(Builder &b, - const SmallVectorImpl &mixedValues, - const int64_t dynamicValueIndicator) { + const SmallVectorImpl &mixedValues) { SmallVector staticValues; SmallVector dynamicValues; for (const auto &it : mixedValues) { if (it.is()) { staticValues.push_back(it.get().cast().getInt()); } else { - staticValues.push_back(dynamicValueIndicator); + staticValues.push_back(ShapedType::kDynamic); dynamicValues.push_back(it.get()); } } return {b.getI64ArrayAttr(staticValues), dynamicValues}; } - -std::pair> mlir::decomposeMixedStridesOrOffsets( - OpBuilder &b, const SmallVectorImpl &mixedValues) { - return decomposeMixedValues(b, 
mixedValues, ShapedType::kDynamic); -} - -std::pair> -mlir::decomposeMixedSizes(OpBuilder &b, - const SmallVectorImpl &mixedValues) { - return decomposeMixedValues(b, mixedValues, ShapedType::kDynamic); -}