diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1788,6 +1788,15 @@
         ArrayRef<int64_t> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
         ArrayRef<int64_t> outerDimsPerm = {});
 
+    // Returns true if a padding value is required, i.e. there is enough static
+    // information to prove that some constant tile size does not evenly divide
+    // the corresponding static dimension of the input tensor. `innerDimsPos`
+    // and `innerTiles` must have the same size, and every entry of
+    // `innerDimsPos` must be a valid index into `inputShape`.
+    static bool requirePaddingValue(ArrayRef<int64_t> inputShape,
+                                    ArrayRef<int64_t> innerDimsPos,
+                                    ArrayRef<OpFoldResult> innerTiles);
+
     static Value createDestinationTensor(OpBuilder &b, Location loc,
         Value source, ArrayRef<OpFoldResult> innerTileSizes,
         ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm);
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -3422,22 +3422,17 @@
   return getStaticTilesImpl(*this);
 }
 
-/// Check if we have enough static information to catch undefined behavior when
-/// the tile size does not divide perfectly the dimension of the input tensor.
-static bool
-areNotFullTiles(ArrayRef<int64_t> inputShape,
-                DenseMap<int64_t, OpFoldResult> const &dimAndTileMapping) {
-  int64_t rank = inputShape.size();
-  for (int64_t dim = 0; dim < rank; dim++) {
-    if (ShapedType::isDynamic(inputShape[dim]))
+bool PackOp::requirePaddingValue(ArrayRef<int64_t> inputShape,
+                                 ArrayRef<int64_t> innerDimsPos,
+                                 ArrayRef<OpFoldResult> innerTiles) {
+  for (auto [pos, tileSize] : llvm::zip_equal(innerDimsPos, innerTiles)) {
+    if (ShapedType::isDynamic(inputShape[pos]))
       continue;
-    auto it = dimAndTileMapping.find(dim);
-    if (it == dimAndTileMapping.end())
-      continue;
-    std::optional<int64_t> constantTile = getConstantIntValue(it->second);
-    if (!constantTile)
+    std::optional<int64_t> constantTile = getConstantIntValue(tileSize);
+    // Skip zero tiles (rejected by the verifier) to avoid a division by zero.
+    if (!constantTile || *constantTile == 0)
       continue;
-    if (inputShape[dim] % (*constantTile) != 0)
+    if (inputShape[pos] % (*constantTile) != 0)
       return true;
   }
   return false;
@@ -3458,9 +3453,9 @@
            << " but got: " << paddingValue.getType();
   }
 
-  auto dimAndTileMapping = getDimAndTileMapping();
   if (!paddingValue &&
-      areNotFullTiles(getSourceType().getShape(), dimAndTileMapping)) {
+      requirePaddingValue(getSourceType().getShape(), getInnerDimsPos(),
+                          getMixedTiles())) {
     return emitOpError("invalid tile factor provided. Only full tiles are "
                        "supported when padding_value is not set");
   }