diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td --- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td +++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td @@ -1772,6 +1772,14 @@ ]; let extraClassDeclaration = commonExtraClassDeclaration # [{ + // Returns the result shape as OpFoldResults, which are Values exactly for + // the dims that are dynamic in the inferred result type. Static so that the + // destination shape can be computed before creating a `pack` op. + static SmallVector getResultShape(OpBuilder &builder, + Location loc, ArrayRef sourceDims, + ArrayRef innerTileDims, ArrayRef innerDimsPos, + ArrayRef outerDimsPerm = {}); + // Method to get the `ShapedType` of the result based on the inner tiles, // position of the inner tiles (innerDimsPos) and interchange vector of // outer loops (outerDimsPerm). diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -3479,14 +3479,29 @@ return success(); } -/// Get the expected packed type based on source type, tile factors, position of -/// the inner tiles and permutation of the outer tiled loop. -ShapedType PackOp::inferPackedType(ShapedType sourceType, - ArrayRef innerTileSizes, - ArrayRef innerDimsPos, - ArrayRef outerDimsPerm) { - SmallVector resultShape = llvm::to_vector(sourceType.getShape()); - for (const auto &tiledDim : llvm::enumerate(innerDimsPos)) { +/// Converts OpFoldResults to int64_t shape entries, unconditionally mapping all +/// Values to kDynamic, even those defined by arith.constant ops. +static SmallVector +asShapeWithAnyValueAsDynamic(ArrayRef ofrs) { + SmallVector result; + for (auto o : ofrs) { + // Test for Value first: getConstantIntValue special-cases constant Values.
+ if (o.dyn_cast()) + result.push_back(ShapedType::kDynamic); + else + result.push_back(getConstantIntValue(o).value_or(ShapedType::kDynamic)); + } + return result; +} + +/// Helper for PackOp::{getResultShape,inferPackedType}. Returns the packed +/// shape as int64_t sizes, with ShapedType::kDynamic for dynamic dims. Sharing +/// this helper guarantees that both methods agree on which dims are dynamic. +static SmallVector getPackOpResultTypeShape( + ArrayRef sourceShape, ArrayRef innerTileSizes, + ArrayRef innerDimsPos, ArrayRef outerDimsPerm) { + SmallVector resultShape = llvm::to_vector(sourceShape); + for (auto tiledDim : llvm::enumerate(innerDimsPos)) { if (ShapedType::isDynamic(resultShape[tiledDim.value()])) continue; if (ShapedType::isDynamic(innerTileSizes[tiledDim.index()])) { @@ -3497,11 +3512,60 @@ innerTileSizes[tiledDim.index()]); } + // Permute the outer dims if outer_dims_perm is provided. if (!outerDimsPerm.empty()) applyPermutationToVector(resultShape, outerDimsPerm); // Append the inner tile dimensions.
resultShape.append(innerTileSizes.begin(), innerTileSizes.end()); + return resultShape; +} + +SmallVector PackOp::getResultShape( + OpBuilder &builder, Location loc, ArrayRef sourceDims, + ArrayRef innerTileSizes, ArrayRef innerDimsPos, + ArrayRef outerDimsPerm) { + SmallVector resultDims = llvm::to_vector(sourceDims); + + AffineExpr s0, s1; + bindSymbols(builder.getContext(), s0, s1); + AffineExpr ceilDivExpr = s0.ceilDiv(s1); + for (auto tiledDim : llvm::enumerate(innerDimsPos)) { + resultDims[tiledDim.value()] = makeComposedFoldedAffineApply( + builder, loc, ceilDivExpr, + {resultDims[tiledDim.value()], innerTileSizes[tiledDim.index()]}); + } + if (!outerDimsPerm.empty()) + applyPermutationToVector(resultDims, outerDimsPerm); + resultDims.append(innerTileSizes.begin(), innerTileSizes.end()); + + SmallVector resultTypeShape = + getPackOpResultTypeShape(asShapeWithAnyValueAsDynamic(sourceDims), + asShapeWithAnyValueAsDynamic(innerTileSizes), + innerDimsPos, outerDimsPerm); + + // Fix-up `resultDims` so that an entry is a Value if and only if the result + // type shape marks that dim as dynamic. This is needed as callers may use + // dispatchIndexOpFoldResults on the result and rely on the exact number of + // dynamic dims it returns. + for (unsigned i = 0; i < resultDims.size(); ++i) { + if (!ShapedType::isDynamic(resultTypeShape[i])) + continue; + resultDims[i] = + getValueOrCreateConstantIndexOp(builder, loc, resultDims[i]); + } + + return resultDims; +} + +/// Get the expected packed type based on the source type, tile sizes, +/// positions of the inner tiles and permutation of the outer tiled dims. +ShapedType PackOp::inferPackedType(ShapedType sourceType, + ArrayRef innerTileSizes, + ArrayRef innerDimsPos, + ArrayRef outerDimsPerm) { + SmallVector resultShape = getPackOpResultTypeShape( + sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm); return RankedTensorType::get(resultShape, sourceType.getElementType()); }