diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1415,12 +1415,11 @@
     /// Return the rank of the result type.
     unsigned getResultRank() { return getType().getRank(); }
 
-    /// Return the expected rank of each of the`static_offsets`, `static_sizes`
-    /// and `static_strides` attributes.
-    std::array<unsigned, 3> getArrayAttrMaxRanks() {
-      unsigned resultRank = getType().getRank();
-      return {1, resultRank, resultRank};
-    }
+    /// Return the expected size of each of the `static_offsets`, `static_sizes`
+    /// and `static_strides` attributes.
+    int64_t getOffsetsRank() { return 1; }
+    int64_t getSizesRank() { return getType().getRank(); }
+    int64_t getStridesRank() { return getType().getRank(); }
 
     /// Return the number of leading operands before the `offsets`, `sizes` and
     /// and `strides` operands.
@@ -2052,13 +2051,6 @@
                                 ArrayRef<int64_t> staticSizes,
                                 ArrayRef<int64_t> staticStrides);
 
-    /// Return the expected rank of each of the`static_offsets`, `static_sizes`
-    /// and `static_strides` attributes.
-    std::array<unsigned, 3> getArrayAttrMaxRanks() {
-      unsigned rank = getSourceType().getRank();
-      return {rank, rank, rank};
-    }
-
     /// Return the number of leading operands before the `offsets`, `sizes` and
     /// and `strides` operands.
     static unsigned getOffsetSizeAndStrideStartOperandIndex() { return 1; }
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -448,12 +448,11 @@
                                               ArrayRef<int64_t> staticSizes,
                                               ArrayRef<int64_t> staticStrides);
 
-    /// Return the expected rank of each of the`static_offsets`, `static_sizes`
-    /// and `static_strides` attributes.
-    std::array<unsigned, 3> getArrayAttrMaxRanks() {
-      unsigned rank = getSourceType().getRank();
-      return {rank, rank, rank};
-    }
+    /// Return the expected size of each of the `static_offsets`, `static_sizes`
+    /// and `static_strides` attributes.
+    int64_t getOffsetsRank() { return getSourceType().getRank(); }
+    int64_t getSizesRank() { return getSourceType().getRank(); }
+    int64_t getStridesRank() { return getSourceType().getRank(); }
 
     /// Return the number of leading operands before the `offsets`, `sizes` and
     /// and `strides` operands.
@@ -877,12 +876,11 @@
       return getResultType();
     }
 
-    /// Return the expected rank of each of the`static_offsets`, `static_sizes`
-    /// and `static_strides` attributes.
-    std::array<unsigned, 3> getArrayAttrMaxRanks() {
-      unsigned rank = getResultType().getRank();
-      return {rank, rank, rank};
-    }
+    /// Return the expected size of each of the `static_offsets`, `static_sizes`
+    /// and `static_strides` attributes.
+    int64_t getOffsetsRank() { return getDestType().getRank(); }
+    int64_t getSizesRank() { return getDestType().getRank(); }
+    int64_t getStridesRank() { return getDestType().getRank(); }
 
     /// Return the dimensions of the dest that are omitted to insert a source
     /// when the result is rank-extended.
@@ -1461,12 +1459,11 @@
           getOperation()->getParentOp());
     }
 
-    /// Return the expected rank of each of the `static_offsets`, `static_sizes`
-    /// and `static_strides` attributes.
-    std::array<unsigned, 3> getArrayAttrMaxRanks() {
-      unsigned rank = getDestType().getRank();
-      return {rank, rank, rank};
-    }
+    /// Return the expected size of each of the `static_offsets`, `static_sizes`
+    /// and `static_strides` attributes.
+    int64_t getOffsetsRank() { return getDestType().getRank(); }
+    int64_t getSizesRank() { return getDestType().getRank(); }
+    int64_t getStridesRank() { return getDestType().getRank(); }
 
     /// Return the number of leading operands before `offsets`, `sizes` and
     /// `strides` operands.
diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.h b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
--- a/mlir/include/mlir/Interfaces/ViewLikeInterface.h
+++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
@@ -37,6 +37,12 @@
 unsigned getNumDynamicEntriesUpToIdx(ArrayRef<int64_t> staticVals,
                                      unsigned idx);
 
+/// Default implementation of `getOffsetsRank`, `getSizesRank` and
+/// `getStridesRank` of `OffsetSizeAndStrideOpInterface`: returns the rank of
+/// the view source if `op` implements `ViewLikeOpInterface` and the view
+/// source has a ranked shaped type, and failure otherwise.
+FailureOr<int64_t> defaultArrayRank(Operation *op);
+
 } // namespace detail
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.td b/mlir/include/mlir/Interfaces/ViewLikeInterface.td
--- a/mlir/include/mlir/Interfaces/ViewLikeInterface.td
+++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.td
@@ -35,7 +35,7 @@
     Common interface for ops that allow specifying mixed dynamic and static
     offsets, sizes and strides variadic operands.
     Ops that implement this interface need to expose the following methods:
-      1. `getArrayAttrMaxRanks` to specify the length of static integer
-         attributes.
+      1. `getOffsetsRank`, `getSizesRank` and `getStridesRank` to specify the
+         lengths of the static integer attributes.
       2. `offsets`, `sizes` and `strides` variadic operands.
       3. `static_offsets`, resp. `static_sizes` and `static_strides` integer
@@ -44,10 +44,11 @@
          starting index of the OffsetSizeAndStrideOpInterface operands
 
     The invariants of this interface are:
-      1. `static_offsets`, `static_sizes` and `static_strides` have length
-         exactly `getArrayAttrMaxRanks()`[0] (resp. [1], [2]).
-      2. `offsets`, `sizes` and `strides` have each length at most
-         `getArrayAttrMaxRanks()`[0] (resp. [1], [2]).
+      1. `static_offsets`, `static_sizes` and `static_strides` have a length of
+         exactly `getOffsetsRank`, `getSizesRank` and `getStridesRank`,
+         respectively.
+      2. `offsets`, `sizes` and `strides` each have a length of at most
+         `getOffsetsRank`, `getSizesRank` and `getStridesRank`, respectively.
       3. if an entry of `static_offsets` (resp. `static_sizes`,
          `static_strides`) is equal to a special sentinel value, namely
          `ShapedType::kDynamic`, then the corresponding entry is a dynamic
@@ -76,12 +77,48 @@
     >,
     InterfaceMethod<
       /*desc=*/[{
-        Return the expected rank of each of the`static_offsets`, `static_sizes`
-        and `static_strides` attributes.
+        Return the expected size of the `static_offsets` attribute.
       }],
-      /*retTy=*/"std::array<unsigned, 3>",
-      /*methodName=*/"getArrayAttrMaxRanks",
-      /*args=*/(ins)
+      /*retTy=*/"int64_t",
+      /*methodName=*/"getOffsetsRank",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        FailureOr<int64_t> rank = ::mlir::detail::defaultArrayRank($_op);
+        if (succeeded(rank))
+          return *rank;
+        llvm_unreachable("getOffsetsRank not implemented");
+      }]
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return the expected size of the `static_sizes` attribute.
+      }],
+      /*retTy=*/"int64_t",
+      /*methodName=*/"getSizesRank",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        FailureOr<int64_t> rank = ::mlir::detail::defaultArrayRank($_op);
+        if (succeeded(rank))
+          return *rank;
+        llvm_unreachable("getSizesRank not implemented");
+      }]
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return the expected size of the `static_strides` attribute.
+      }],
+      /*retTy=*/"int64_t",
+      /*methodName=*/"getStridesRank",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        FailureOr<int64_t> rank = ::mlir::detail::defaultArrayRank($_op);
+        if (succeeded(rank))
+          return *rank;
+        llvm_unreachable("getStridesRank not implemented");
+      }]
     >,
     InterfaceMethod<
       /*desc=*/[{
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2880,11 +2880,12 @@
 /// with `b` at location `loc`.
 SmallVector<Range, 8> mlir::getOrCreateRanges(OffsetSizeAndStrideOpInterface op,
                                               OpBuilder &b, Location loc) {
-  std::array<unsigned, 3> ranks = op.getArrayAttrMaxRanks();
-  assert(ranks[0] == ranks[1] && "expected offset and sizes of equal ranks");
-  assert(ranks[1] == ranks[2] && "expected sizes and strides of equal ranks");
+  int64_t rank = op.getOffsetsRank();
+  assert(rank == op.getSizesRank() &&
+         "expected offset and sizes of equal ranks");
+  assert(op.getSizesRank() == op.getStridesRank() &&
+         "expected sizes and strides of equal ranks");
   SmallVector<Range, 8> res;
-  unsigned rank = ranks[0];
   res.reserve(rank);
   for (unsigned idx = 0; idx < rank; ++idx) {
     Value offset =
diff --git a/mlir/lib/Interfaces/ViewLikeInterface.cpp b/mlir/lib/Interfaces/ViewLikeInterface.cpp
--- a/mlir/lib/Interfaces/ViewLikeInterface.cpp
+++ b/mlir/lib/Interfaces/ViewLikeInterface.cpp
@@ -38,12 +38,12 @@
 
 LogicalResult
 mlir::detail::verifyOffsetSizeAndStrideOp(OffsetSizeAndStrideOpInterface op) {
-  std::array<unsigned, 3> maxRanks = op.getArrayAttrMaxRanks();
   // Offsets can come in 2 flavors:
-  //   1. Either single entry (when maxRanks == 1).
+  //   1. Either single entry (when getOffsetsRank() == 1).
   //   2. Or as an array whose rank must match that of the mixed sizes.
   // So that the result type is well-formed.
-  if (!(op.getMixedOffsets().size() == 1 && maxRanks[0] == 1) && // NOLINT
+  if (!(op.getMixedOffsets().size() == 1 &&
+        op.getOffsetsRank() == 1) && // NOLINT
       op.getMixedOffsets().size() != op.getMixedSizes().size())
     return op->emitError(
                "expected mixed offsets rank to match mixed sizes rank (")
@@ -57,14 +57,16 @@
            << op.getMixedSizes().size() << " vs " << op.getMixedStrides().size()
            << ") so the rank of the result type is well-formed.";
 
-  if (failed(verifyListOfOperandsOrIntegers(
-          op, "offset", maxRanks[0], op.getStaticOffsets(), op.getOffsets())))
+  if (failed(verifyListOfOperandsOrIntegers(op, "offset", op.getOffsetsRank(),
+                                            op.getStaticOffsets(),
+                                            op.getOffsets())))
     return failure();
   if (failed(verifyListOfOperandsOrIntegers(
-          op, "size", maxRanks[1], op.getStaticSizes(), op.getSizes())))
+          op, "size", op.getSizesRank(), op.getStaticSizes(), op.getSizes())))
     return failure();
-  if (failed(verifyListOfOperandsOrIntegers(
-          op, "stride", maxRanks[2], op.getStaticStrides(), op.getStrides())))
+  if (failed(verifyListOfOperandsOrIntegers(op, "stride", op.getStridesRank(),
+                                            op.getStaticStrides(),
+                                            op.getStrides())))
     return failure();
   return success();
 }
@@ -202,3 +204,13 @@
   return std::count_if(staticVals.begin(), staticVals.begin() + idx,
                        [&](int64_t val) { return ShapedType::isDynamic(val); });
 }
+
+FailureOr<int64_t> mlir::detail::defaultArrayRank(Operation *op) {
+  auto viewLikeOp = dyn_cast<ViewLikeOpInterface>(op);
+  if (!viewLikeOp)
+    return failure();
+  auto shapedType = dyn_cast<ShapedType>(viewLikeOp.getViewSource().getType());
+  if (!shapedType || !shapedType.hasRank())
+    return failure();
+  return shapedType.getRank();
+}