diff --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.h b/mlir/include/mlir/Interfaces/InferTypeOpInterface.h
--- a/mlir/include/mlir/Interfaces/InferTypeOpInterface.h
+++ b/mlir/include/mlir/Interfaces/InferTypeOpInterface.h
@@ -28,6 +28,13 @@
 class ShapedTypeComponents;
 using ReifiedRankedShapedTypeDims = SmallVector<SmallVector<OpFoldResult>>;
 
+/// Reify the shape of the result of an operation (typically in terms of the
+/// shape of its operands). Returns failure if `op` does not implement
+/// ReifyRankedShapedTypeOpInterface or if the interface implementation fails.
+LogicalResult
+reifyResultShapes(OpBuilder &b, Operation *op,
+                  ReifiedRankedShapedTypeDims &reifiedReturnShapes);
+
 /// Adaptor class to abstract the differences between whether value is from
 /// a ShapedType or ShapedTypeComponents or DenseIntElementsAttribute.
 class ShapeAdaptor {
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -138,17 +138,15 @@
     bool reifiedShapes = false;
     if (shapedValue.getType().isa<RankedTensorType>() &&
         shapedValue.isa<OpResult>()) {
-      if (auto rankedOp = dyn_cast_or_null<ReifyRankedShapedTypeOpInterface>(
-              shapedValue.getDefiningOp())) {
-        ReifiedRankedShapedTypeDims resultDims;
-        if (succeeded(rankedOp.reifyResultShapes(b, resultDims))) {
-          reifiedShapes = true;
-          auto &shape =
-              resultDims[shapedValue.cast<OpResult>().getResultNumber()];
-          for (const auto &dim : enumerate(tensorType.getShape()))
-            if (ShapedType::isDynamic(dim.value()))
-              dynamicSizes.push_back(shape[dim.index()].get<Value>());
-        }
+      ReifiedRankedShapedTypeDims resultDims;
+      if (succeeded(
+              reifyResultShapes(b, shapedValue.getDefiningOp(), resultDims))) {
+        reifiedShapes = true;
+        auto &shape =
+            resultDims[shapedValue.cast<OpResult>().getResultNumber()];
+        for (const auto &dim : enumerate(tensorType.getShape()))
+          if (ShapedType::isDynamic(dim.value()))
+            dynamicSizes.push_back(shape[dim.index()].get<Value>());
       }
     }
 
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -482,9 +482,7 @@
       return failure();
 
     ReifiedRankedShapedTypeDims reifiedShape;
-    ReifyRankedShapedTypeOpInterface interface =
-        cast<ReifyRankedShapedTypeOpInterface>(padOp.getOperation());
-    if (failed(interface.reifyResultShapes(rewriter, reifiedShape)))
+    if (failed(reifyResultShapes(rewriter, padOp, reifiedShape)))
       return rewriter.notifyMatchFailure(
           padOp, "failed to reify tensor.pad op result shape");
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
@@ -125,19 +125,17 @@
     return {};
 
   // Try to reify dynamic sizes.
-  if (auto reifiableOp =
-          value.getDefiningOp<ReifyRankedShapedTypeOpInterface>()) {
-    ReifiedRankedShapedTypeDims reifiedShape;
-    if (succeeded(reifiableOp.reifyResultShapes(b, reifiedShape))) {
-      SmallVector<Value> dynSizes;
-      for (int64_t i = 0; i < tensorType.getRank(); ++i) {
-        if (tensorType.isDynamicDim(i))
-          dynSizes.push_back(
-              reifiedShape[value.cast<OpResult>().getResultNumber()][i]
-                  .get<Value>());
-      }
-      return dynSizes;
+  ReifiedRankedShapedTypeDims reifiedShape;
+  if (value.isa<OpResult>() &&
+      succeeded(reifyResultShapes(b, value.getDefiningOp(), reifiedShape))) {
+    SmallVector<Value> dynSizes;
+    for (int64_t i = 0; i < tensorType.getRank(); ++i) {
+      if (tensorType.isDynamicDim(i))
+        dynSizes.push_back(
+            reifiedShape[value.cast<OpResult>().getResultNumber()][i]
+                .get<Value>());
     }
+    return dynSizes;
   }
 
   // Create tensor.dim ops.
@@ -293,8 +291,7 @@
   Location loc = padOp.getLoc();
   RankedTensorType resultType = padOp.getResultType();
   ReifiedRankedShapedTypeDims reifiedShape;
-  if (failed(cast<ReifyRankedShapedTypeOpInterface>(padOp.getOperation())
-                 .reifyResultShapes(rewriter, reifiedShape)))
+  if (failed(reifyResultShapes(rewriter, padOp, reifiedShape)))
     return rewriter.notifyMatchFailure(
         padOp, "failed to reify tensor.pad op result shape");
   SmallVector<Value> dynamicSizes;
diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusePadOpWithLinalgProducer.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusePadOpWithLinalgProducer.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/FusePadOpWithLinalgProducer.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusePadOpWithLinalgProducer.cpp
@@ -62,10 +62,7 @@
           padOp, "only supported for ops with all parallel iterator types");
     }
     ReifiedRankedShapedTypeDims resultShape;
-    ReifyRankedShapedTypeOpInterface reifyShapedTypeInterface =
-        dyn_cast<ReifyRankedShapedTypeOpInterface>(padOp.getOperation());
-    if (failed(reifyShapedTypeInterface.reifyResultShapes(rewriter,
-                                                          resultShape)) ||
+    if (failed(reifyResultShapes(rewriter, padOp, resultShape)) ||
         resultShape.size() != 1) {
       return rewriter.notifyMatchFailure(
           padOp, "failed to get shape of pad op result");
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -205,8 +205,7 @@
   }
 
   ReifiedRankedShapedTypeDims reifiedResultShapes;
-  if (failed(cast<ReifyRankedShapedTypeOpInterface>(opToPad.getOperation())
-                 .reifyResultShapes(rewriter, reifiedResultShapes))) {
+  if (failed(reifyResultShapes(rewriter, opToPad, reifiedResultShapes))) {
     LLVM_DEBUG(DBGS() << "--failed to reify result shapes -> FAIL\n");
     return rewriter.notifyMatchFailure(opToPad,
                                        "failed to reify result shapes");
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
--- a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
@@ -84,33 +84,18 @@
     OpResult dimValue = dimOp.getSource().template dyn_cast<OpResult>();
     if (!dimValue)
       return failure();
-    auto rankedShapeTypeOp =
-        dyn_cast<ReifyRankedShapedTypeOpInterface>(dimValue.getOwner());
-    if (!rankedShapeTypeOp)
-      return failure();
-
     std::optional<int64_t> dimIndex = dimOp.getConstantIndex();
     if (!dimIndex)
       return failure();
 
     ReifiedRankedShapedTypeDims reifiedResultShapes;
-    if (failed(
-            rankedShapeTypeOp.reifyResultShapes(rewriter, reifiedResultShapes)))
-      return failure();
-
-    if (reifiedResultShapes.size() != rankedShapeTypeOp->getNumResults())
+    if (failed(reifyResultShapes(rewriter, dimValue.getOwner(),
+                                 reifiedResultShapes)))
       return failure();
-
     unsigned resultNumber = dimValue.getResultNumber();
-    auto sourceType = dimValue.getType().dyn_cast<ShapedType>();
-    if (reifiedResultShapes[resultNumber].size() !=
-        static_cast<size_t>(sourceType.getRank()))
-      return failure();
-
-    rewriter.replaceOp(dimOp,
-                       getValueOrCreateConstantIndexOp(
-                           rewriter, dimOp.getLoc(),
-                           reifiedResultShapes[resultNumber][*dimIndex]));
+    Value replacement = getValueOrCreateConstantIndexOp(
+        rewriter, dimOp.getLoc(), reifiedResultShapes[resultNumber][*dimIndex]);
+    rewriter.replaceOp(dimOp, replacement);
     return success();
   }
 };
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -81,11 +81,7 @@
   if (!tensorType.hasStaticShape()) {
     // Dynamic shape: Query ReifyRankedShapedTypeOpInterface.
     ReifiedRankedShapedTypeDims reifiedShapes;
-    ReifyRankedShapedTypeOpInterface reifyShapedTypeInterface =
-        dyn_cast<ReifyRankedShapedTypeOpInterface>(opResult.getDefiningOp());
-    if (!reifyShapedTypeInterface)
-      return failure();
-    if (failed(reifyShapedTypeInterface.reifyResultShapes(b, reifiedShapes)))
+    if (failed(reifyResultShapes(b, opResult.getDefiningOp(), reifiedShapes)))
       return failure();
     mixedSizes = reifiedShapes[opResult.getResultNumber()];
   } else {
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
--- a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
@@ -34,10 +34,7 @@
 
   SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const {
     ReifiedRankedShapedTypeDims reifiedShapes;
-    ReifyRankedShapedTypeOpInterface reifyShapedTypeInterface =
-        dyn_cast<ReifyRankedShapedTypeOpInterface>(op);
-    (void)reifyShapedTypeInterface.reifyResultShapes(b, reifiedShapes);
-
+    (void)reifyResultShapes(b, op, reifiedShapes);
     Location loc = op->getLoc();
     Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
     Value one = b.create<arith::ConstantIndexOp>(loc, 1);
@@ -84,7 +81,7 @@
   Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
   Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
   ReifiedRankedShapedTypeDims resultShape;
-  (void)op.reifyResultShapes(builder, resultShape);
+  (void)reifyResultShapes(builder, op, resultShape);
   SmallVector<Range> loopBounds(rank);
   for (auto dim : llvm::seq<int64_t>(0, rank)) {
     loopBounds[dim].offset = zero;
@@ -216,7 +213,7 @@
     resultOffsets.append(outputRank - inputRank, zeroAttr);
 
     ReifiedRankedShapedTypeDims outputShape;
-    (void)packOp.reifyResultShapes(b, outputShape);
+    (void)reifyResultShapes(b, packOp, outputShape);
     resultSizes.assign(sizes.begin(), sizes.end());
     for (auto dataTileDim : llvm::seq<unsigned>(inputRank, outputRank))
       resultSizes.push_back(outputShape[0][dataTileDim]);
diff --git a/mlir/lib/Dialect/Tensor/Transforms/EmptyOpPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/EmptyOpPatterns.cpp
--- a/mlir/lib/Dialect/Tensor/Transforms/EmptyOpPatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/EmptyOpPatterns.cpp
@@ -26,10 +26,7 @@
       return failure();
     Location loc = reshapeOp.getLoc();
     ReifiedRankedShapedTypeDims resultShapes;
-    ReifyRankedShapedTypeOpInterface reifyShapedTypeInterface =
-        cast<ReifyRankedShapedTypeOpInterface>(reshapeOp.getOperation());
-    if (failed(reifyShapedTypeInterface.reifyResultShapes(rewriter,
-                                                          resultShapes)) ||
+    if (failed(reifyResultShapes(rewriter, reshapeOp, resultShapes)) ||
         !llvm::hasSingleElement(resultShapes))
       return failure();
     // TODO: Do not drop tensor type encoding.
diff --git a/mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshapeUtils.cpp b/mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshapeUtils.cpp
--- a/mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshapeUtils.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshapeUtils.cpp
@@ -112,9 +112,7 @@
   // Materialize the output shape of the collapse_shape operation. This will
   // create IR describing the output shape in terms of the input shape.
   ReifiedRankedShapedTypeDims reifiedShapes;
-  ReifyRankedShapedTypeOpInterface reifyShapedTypeInterface =
-      dyn_cast<ReifyRankedShapedTypeOpInterface>(op.getOperation());
-  if (failed(reifyShapedTypeInterface.reifyResultShapes(b, reifiedShapes)))
+  if (failed(reifyResultShapes(b, op, reifiedShapes)))
     return failure();
   SmallVector<OpFoldResult> &collapseShapeOutputShape = reifiedShapes[0];
   SmallVector<ReassociationIndices, 4> reassociationIndices =
diff --git a/mlir/lib/Interfaces/InferTypeOpInterface.cpp b/mlir/lib/Interfaces/InferTypeOpInterface.cpp
--- a/mlir/lib/Interfaces/InferTypeOpInterface.cpp
+++ b/mlir/lib/Interfaces/InferTypeOpInterface.cpp
@@ -22,6 +22,58 @@
 #include "mlir/Interfaces/InferTypeOpInterface.cpp.inc"
 } // namespace mlir
 
+LogicalResult
+mlir::reifyResultShapes(OpBuilder &b, Operation *op,
+                        ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
+  auto reifiableOp = dyn_cast<ReifyRankedShapedTypeOpInterface>(op);
+  if (!reifiableOp)
+    return failure();
+  LogicalResult status = reifiableOp.reifyResultShapes(b, reifiedReturnShapes);
+#ifndef NDEBUG
+  // A failed reification may leave `reifiedReturnShapes` empty or partially
+  // populated, so only successful results are validated below.
+  if (failed(status))
+    return status;
+  // Assert that ReifyRankedShapedTypeOpInterface::reifyResultShapes produced
+  // a correct result. Shaped results are matched to `reifiedReturnShapes` by
+  // their position among the shaped results of `op`.
+  // NOTE(review): several callers in this patch index the result by
+  // OpResult::getResultNumber(); this agrees only if every result is shaped.
+  const int64_t numReified = static_cast<int64_t>(reifiedReturnShapes.size());
+  int64_t resultIdx = 0;
+  for (OpResult result : op->getResults()) {
+    auto shapedType = result.getType().dyn_cast<ShapedType>();
+    if (!shapedType)
+      continue;
+    // Assert one entry per shaped result (checked before indexing).
+    assert(resultIdx < numReified &&
+           "incorrect implementation of ReifyRankedShapedTypeOpInterface");
+    if (!shapedType.hasRank()) {
+      // Nothing to check for unranked shaped values.
+      ++resultIdx;
+      continue;
+    }
+    // Assert one OpFoldResult per dimension.
+    assert(shapedType.getRank() ==
+               static_cast<int64_t>(reifiedReturnShapes[resultIdx].size()) &&
+           "incorrect implementation of ReifyRankedShapedTypeOpInterface");
+    for (int64_t dim = 0; dim < shapedType.getRank(); ++dim) {
+      // reifyResultShapes must return:
+      // * Attribute for static dimensions
+      // * Value for dynamic dimensions
+      assert(shapedType.isDynamicDim(dim) ==
+                 reifiedReturnShapes[resultIdx][dim].is<Value>() &&
+             "incorrect implementation of ReifyRankedShapedTypeOpInterface");
+    }
+    ++resultIdx;
+  }
+  // Assert that every shaped value result was reified.
+  assert(resultIdx == numReified &&
+         "incorrect implementation of ReifyRankedShapedTypeOpInterface");
+#endif // NDEBUG
+  return status;
+}
+
 bool ShapeAdaptor::hasRank() const {
   if (val.isNull())
     return false;
diff --git a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
--- a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
@@ -188,7 +188,7 @@
 
     // Materialize the output shape values of the slice operation.
     ReifiedRankedShapedTypeDims reifiedShapes;
-    if (failed(op.reifyResultShapes(rewriter, reifiedShapes)))
+    if (failed(reifyResultShapes(rewriter, op, reifiedShapes)))
       return rewriter.notifyMatchFailure(op, "failed to reify result shapes");
 
     // Create the destination tensor using the above values.
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -1241,11 +1241,16 @@
   Location loc = getLoc();
   shapes.reserve(getNumOperands());
   for (Value operand : llvm::reverse(getOperands())) {
+    auto tensorType = operand.getType().cast<RankedTensorType>();
     auto currShape = llvm::to_vector<4>(llvm::map_range(
-        llvm::seq<int64_t>(
-            0, operand.getType().cast<RankedTensorType>().getRank()),
+        llvm::seq<int64_t>(0, tensorType.getRank()),
         [&](int64_t dim) -> OpFoldResult {
-          return builder.createOrFold<tensor::DimOp>(loc, operand, dim);
+          return tensorType.isDynamicDim(dim)
+                     ? static_cast<OpFoldResult>(
+                           builder.createOrFold<tensor::DimOp>(loc, operand,
+                                                               dim))
+                     : static_cast<OpFoldResult>(
+                           builder.getIndexAttr(tensorType.getDimSize(dim)));
         }));
     shapes.emplace_back(std::move(currShape));
   }