diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h b/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h --- a/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h +++ b/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h @@ -33,11 +33,6 @@ namespace mlir { namespace bufferization { -/// Populate `dynamicDims` with tensor::DimOp / memref::DimOp results for all -/// dynamic dimensions of the given shaped value. -void populateDynamicDimSizes(OpBuilder &b, Location loc, Value shapedValue, - SmallVector &dynamicDims); - /// Try to cast the given ranked MemRef-typed value to the given ranked MemRef /// type. Insert a reallocation + copy if it cannot be statically guaranteed /// that a direct cast would be valid. diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h --- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h @@ -45,10 +45,6 @@ /// Check if iterator type has "reduction" semantics. bool isReductionIterator(utils::IteratorType iteratorType); -/// Given an operation, retrieves the value of each dynamic dimension through -/// constructing the necessary DimOp operators. -SmallVector getDynOperands(Location loc, Value val, OpBuilder &b); - /// Computes an upper bound for the result `value` of an index computation. /// Translates AffineMinOps and AffineApplyOps along the use-def chains of the /// index computation to affine constraints and projects out intermediate diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h --- a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h +++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h @@ -55,10 +55,6 @@ /// single deallocate if it exists or nullptr. std::optional findDealloc(Value allocValue); -/// Return the dimensions of the given memref value. 
-SmallVector getMixedSizes(OpBuilder &builder, Location loc, - Value value); - /// Create a rank-reducing SubViewOp @[0 .. 0] with strides [1 .. 1] and /// appropriate sizes (i.e. `memref.getSizes()`) to reduce the rank of `memref` /// to that of `targetShape`. diff --git a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h --- a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h +++ b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h @@ -112,10 +112,6 @@ /// that can be folded. LogicalResult foldTensorCast(Operation *op); -/// Return the dimensions of the given tensor value. -SmallVector getMixedSizes(OpBuilder &builder, Location loc, - Value value); - /// Create a rank-reducing ExtractSliceOp @[0 .. 0] with strides [1 .. 1] and /// appropriate sizes (i.e. `tensor.getSizes()`) to reduce the rank of `tensor` /// to that of `targetType`. diff --git a/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h b/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h --- a/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h @@ -21,16 +21,6 @@ PadOp createPadHighOp(RankedTensorType type, Value source, Value pad, bool nofold, Location loc, OpBuilder &builder); -// Creates dim ops for each dynamic dimension of the ranked tensor argument and -// returns these as values. -SmallVector createDynamicDimValues(OpBuilder &b, Location loc, - Value rankedTensor); - -// Creates dim ops or constant ops for each dimension of the ranked tensor -// argument and returns these as values. -SmallVector createDimValues(OpBuilder &b, Location loc, - Value rankedTensor); - /// Returns the transposed `rankedTensorType` if `transposeVector` is non-empty. /// Fail if `transposeVector` is not a permutation matching the tensor rank. 
FailureOr diff --git a/mlir/include/mlir/Interfaces/ShapedTypeInterfaces.h b/mlir/include/mlir/Interfaces/ShapedTypeInterfaces.h --- a/mlir/include/mlir/Interfaces/ShapedTypeInterfaces.h +++ b/mlir/include/mlir/Interfaces/ShapedTypeInterfaces.h @@ -14,6 +14,7 @@ #define MLIR_INTERFACES_SHAPEDTYPEINTERFACES_H_ #include "mlir/IR/OpDefinition.h" +#include "llvm/ADT/SmallVector.h" /// Include the generated interface declarations. #include "mlir/Interfaces/ShapedTypeInterfaces.h.inc" @@ -32,6 +33,18 @@ OpFoldResult reifyShapeDim(OpBuilder &builder, Location loc, Value shapedValue, int64_t dim); +/// Reify all dimensions of the given shaped value. The shaped value must have a +/// statically known rank. This function returns an IntegerAttr for each static +/// dimension size. Otherwise, it returns a Value. +SmallVector reifyShapeDims(OpBuilder &builder, Location loc, + Value shapedValue); + +/// Reify all dynamic dimensions of the given shaped value. The shaped value +/// must have a statically known rank. Returns one Value per dynamic dimension, +/// in dimension order. +SmallVector reifyDynamicShapeDims(OpBuilder &builder, Location loc, + Value shapedValue); + } // namespace mlir #endif // MLIR_INTERFACES_SHAPEDTYPEINTERFACES_H_ diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp --- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp @@ -18,6 +18,7 @@ #include "mlir/IR/TypeUtilities.h" #include "mlir/IR/Value.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" +#include "mlir/Interfaces/ShapedTypeInterfaces.h" #include "llvm/Support/Debug.h" //===----------------------------------------------------------------------===// @@ -152,7 +153,7 @@ // If the shape could not be reified, create DimOps. if (!reifiedShapes) - populateDynamicDimSizes(b, loc, tensor, dynamicSizes); + dynamicSizes = reifyDynamicShapeDims(b, loc, tensor); } // Create AllocTensorOp.
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp --- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp @@ -15,6 +15,7 @@ #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/Matchers.h" +#include "mlir/Interfaces/ShapedTypeInterfaces.h" #include using namespace mlir; @@ -129,22 +130,6 @@ return success(); } -void mlir::bufferization::populateDynamicDimSizes( - OpBuilder &b, Location loc, Value shapedValue, - SmallVector &dynamicDims) { - auto shapedType = shapedValue.getType().cast(); - for (int64_t i = 0; i < shapedType.getRank(); ++i) { - if (shapedType.isDynamicDim(i)) { - if (shapedType.isa()) { - dynamicDims.push_back(b.create(loc, shapedValue, i)); - } else { - assert(shapedType.isa() && "expected tensor"); - dynamicDims.push_back(b.create(loc, shapedValue, i)); - } - } - } -} - //===----------------------------------------------------------------------===// // AllocTensorOp //===----------------------------------------------------------------------===// @@ -176,7 +161,7 @@ SmallVector dynamicDims = getDynamicSizes(); if (getCopy()) { assert(dynamicDims.empty() && "expected either `copy` or `dynamicDims`"); - populateDynamicDimSizes(rewriter, loc, copyBuffer, dynamicDims); + dynamicDims = reifyDynamicShapeDims(rewriter, loc, copyBuffer); } FailureOr alloc = options.createAlloc( rewriter, loc, allocType->cast(), dynamicDims); diff --git a/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt b/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt --- a/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt @@ -18,6 +18,7 @@ MLIRDialect MLIRFuncDialect MLIRIR + MLIRShapedTypeInterfaces MLIRSparseTensorDialect MLIRTensorDialect MLIRMemRefDialect diff --git a/mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt 
b/mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt --- a/mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt +++ b/mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt @@ -17,6 +17,7 @@ MLIRParser MLIRPDLDialect MLIRSCFDialect + MLIRShapedTypeInterfaces MLIRSideEffectInterfaces MLIRTransformDialect MLIRTransformDialectUtils diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -33,6 +33,7 @@ #include "mlir/IR/OpDefinition.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" +#include "mlir/Interfaces/ShapedTypeInterfaces.h" #include "mlir/Interfaces/TilingInterface.h" #include "mlir/Support/LLVM.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -1006,7 +1007,7 @@ auto extractSliceOp = rewriter.create( loc, destTensorType, reshapeOp->getResult(0), SmallVector(destRank, zero), - tensor::getMixedSizes(rewriter, loc, unPackOp->getResult(0)), + reifyShapeDims(rewriter, loc, unPackOp->getResult(0)), SmallVector(destRank, one)); // 7. Replace unPackOp by transposeOp. diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp @@ -23,6 +23,7 @@ #include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" +#include "mlir/Interfaces/ShapedTypeInterfaces.h" #include "llvm/ADT/STLExtras.h" #include "llvm/Support/Debug.h" @@ -186,7 +187,7 @@ // Create memref.tensor_store. 
SmallVector sizes = - getMixedSizes(rewriter, loc, padOp.getSource()); + reifyShapeDims(rewriter, loc, padOp.getSource()); SmallVector strides(padOp.getResultType().getRank(), rewriter.getIndexAttr(1)); Value subview = rewriter.create( @@ -319,7 +320,7 @@ // Create tensor::InsertSliceOp. SmallVector sliceSizes = - getMixedSizes(rewriter, loc, padOp.getSource()); + reifyShapeDims(rewriter, loc, padOp.getSource()); SmallVector sliceStrides(resultType.getRank(), rewriter.getIndexAttr(1)); auto insertSliceOp = rewriter.replaceOpWithNewOp( diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp @@ -16,6 +16,7 @@ #include "mlir/Dialect/Tensor/Utils/Utils.h" #include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/IR/Dominance.h" +#include "mlir/Interfaces/ShapedTypeInterfaces.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/Support/Debug.h" #include @@ -481,7 +482,7 @@ if (auto unPackEmpty = unPackDest.getDefiningOp()) unPackMixedSizes = unPackEmpty.getMixedSizes(); else - unPackMixedSizes = tensor::getMixedSizes(rewriter, loc, unPackDest); + unPackMixedSizes = reifyShapeDims(rewriter, loc, unPackDest); unPackDest = rewriter.create(loc, unPackMixedSizes, genericOutElementType); diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp @@ -26,6 +26,7 @@ #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/BuiltinTypes.h" +#include "mlir/Interfaces/ShapedTypeInterfaces.h" #include "mlir/Transforms/FoldUtils.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/ADT/SetVector.h" @@ -310,7 +311,7 @@ 
rewriter.setInsertionPointAfterValue(op->get()); auto elemType = op->get().getType().cast().getElementType(); auto empty = rewriter.create( - loc, tensor::createDimValues(rewriter, loc, op->get()), elemType); + loc, reifyShapeDims(rewriter, loc, op->get()), elemType); auto [start, end] = genericOp.getDpsInitsPositionRange(); newOutputOperands[op->getOperandNumber() - start] = empty.getResult(); @@ -442,7 +443,7 @@ unsigned rank = origResultType.getRank(); SmallVector offsets(rank, rewriter.getIndexAttr(0)); SmallVector sizes = - tensor::getMixedSizes(rewriter, loc, origOutput); + reifyShapeDims(rewriter, loc, origOutput); SmallVector strides(rank, rewriter.getIndexAttr(1)); return rewriter.createOrFold( loc, result, origOutput, offsets, sizes, strides); diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp @@ -12,6 +12,7 @@ #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" +#include "mlir/Interfaces/ShapedTypeInterfaces.h" #include "mlir/Transforms/DialectConversion.h" namespace mlir { @@ -69,7 +70,7 @@ Value firstOperand = operands.front(); auto rankedTensorType = t.cast(); auto staticShape = llvm::to_vector<4>(rankedTensorType.getShape()); - auto dynamicShape = linalg::getDynOperands(loc, firstOperand, b); + auto dynamicShape = reifyDynamicShapeDims(b, loc, firstOperand); res.push_back(b.create( loc, staticShape, rankedTensorType.getElementType(), dynamicShape)); diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusePadOpWithLinalgProducer.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusePadOpWithLinalgProducer.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/FusePadOpWithLinalgProducer.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/FusePadOpWithLinalgProducer.cpp @@ -15,6 +15,7 
@@ #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Interfaces/ShapedTypeInterfaces.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" using namespace mlir; @@ -88,18 +89,7 @@ // TODO: This insert/extract could be potentially made a utility method. unsigned resultNumber = source.cast().getResultNumber(); SmallVector offsets = padOp.getMixedLowPad(); - SmallVector sizes; - sizes.reserve(offsets.size()); - for (const auto &shape : llvm::enumerate( - source.getType().cast().getShape())) { - if (ShapedType::isDynamic(shape.value())) { - sizes.push_back( - rewriter.create(loc, source, shape.index()) - .getResult()); - } else { - sizes.push_back(rewriter.getIndexAttr(shape.value())); - } - } + SmallVector sizes = reifyShapeDims(rewriter, loc, source); SmallVector strides(offsets.size(), rewriter.getIndexAttr(1)); auto slice = rewriter.create( loc, fillTensor.getResult(0), offsets, sizes, strides); diff --git a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp @@ -23,6 +23,7 @@ #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/Utils/Utils.h" #include "mlir/IR/PatternMatch.h" +#include "mlir/Interfaces/ShapedTypeInterfaces.h" using namespace mlir; using namespace mlir::linalg; @@ -311,8 +312,7 @@ auto t = rankedTensor.getType().cast(); RankedTensorType newT = RankedTensorType::Builder(t).insertDim( reductionDimSize / splitFactor, insertSplitDimension); - SmallVector dims = - tensor::createDynamicDimValues(b, loc, rankedTensor); + SmallVector dims = reifyDynamicShapeDims(b, loc, rankedTensor); Value emptyOrAllocTensor; if (useAlloc) { emptyOrAllocTensor = diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp --- 
a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -27,6 +27,7 @@ #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/Matchers.h" +#include "mlir/Interfaces/ShapedTypeInterfaces.h" #include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -916,15 +917,8 @@ // for copying the PadOp source. auto sourceType = padOp.getSourceType(); // Compute size of source of tensor::PadOp. - SmallVector srcSizes; - for (unsigned dim = 0; dim < sourceType.getRank(); ++dim) { - if (sourceType.isDynamicDim(dim)) { - srcSizes.push_back(rewriter.createOrFold( - padOp.getLoc(), padOp.getSource(), dim)); - } else { - srcSizes.push_back(rewriter.getIndexAttr(sourceType.getDimSize(dim))); - } - } + SmallVector srcSizes = + reifyShapeDims(rewriter, padOp.getLoc(), padOp.getSource()); // Strides of InsertSliceOp are all 1. SmallVector strides(sourceType.getRank(), rewriter.getIndexAttr(1)); @@ -1131,10 +1125,10 @@ int numLoops = transpShape.size(); SmallVector tileStrides(numLoops, oneIdxAttr); SmallVector tileOffsets(numLoops, zeroIdxAttr); + // Tile sizes: one per tiled dim of `dest`, in `innerDimsPos` order. SmallVector tileSizes; for (int dim : innerDimsPos) - tileSizes.push_back(getAsOpFoldResult( - rewriter.createOrFold(loc, unpackOp.getDest(), dim))); + tileSizes.push_back(reifyShapeDim(rewriter, loc, unpackOp.getDest(), dim)); applyPermutationToVector(tileSizes, perm); auto partialTile = rewriter.create( diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -24,6 +24,7 @@ #include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/PatternMatch.h" +#include "mlir/Interfaces/ShapedTypeInterfaces.h" #include "mlir/Support/LLVM.h" #include
"mlir/Transforms/RegionUtils.h" #include "llvm/ADT/Sequence.h" @@ -31,8 +32,8 @@ #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" -#include #include +#include using namespace mlir; using namespace mlir::linalg; @@ -252,11 +253,9 @@ operandDimPos))) return failure(); - Value dynamicDim = linalgOp.hasTensorSemantics() - ? (Value)rewriter.create( - linalgOp.getLoc(), operand, operandDimPos) - : (Value)rewriter.create( - linalgOp.getLoc(), operand, operandDimPos); + Value dynamicDim = getValueOrCreateConstantIndexOp( + rewriter, linalgOp.getLoc(), + reifyShapeDim(rewriter, linalgOp.getLoc(), operand, operandDimPos)); iterSpaceDynamicSizes.push_back(dynamicDim); } diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp --- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp @@ -185,19 +185,6 @@ return iteratorType == utils::IteratorType::reduction; } -/// Given an operation, retrieves the value of each dynamic dimension through -/// constructing the necessary DimOp operators. 
-SmallVector getDynOperands(Location loc, Value val, OpBuilder &b) { - SmallVector dynOperands; - auto shapedType = val.getType().cast(); - for (const auto &dim : llvm::enumerate(shapedType.getShape())) { - if (dim.value() == ShapedType::kDynamic) - dynOperands.push_back(getValueOrCreateConstantIndexOp( - b, loc, reifyShapeDim(b, loc, val, dim.index()))); - } - return dynOperands; -} - void getUpperBoundForIndex(Value value, AffineMap &boundMap, SmallVectorImpl &boundOperands, bool constantRequired) { diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -18,6 +18,7 @@ #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Interfaces/InferTypeOpInterface.h" +#include "mlir/Interfaces/ShapedTypeInterfaces.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Interfaces/ViewLikeInterface.h" #include "llvm/ADT/STLExtras.h" @@ -109,21 +110,6 @@ return NoneType::get(type.getContext()); } -SmallVector memref::getMixedSizes(OpBuilder &builder, - Location loc, Value value) { - auto memrefType = value.getType().cast(); - SmallVector result; - for (int64_t i = 0; i < memrefType.getRank(); ++i) { - if (memrefType.isDynamicDim(i)) { - Value size = builder.create(loc, value, i); - result.push_back(size); - } else { - result.push_back(builder.getIndexAttr(memrefType.getDimSize(i))); - } - } - return result; -} - //===----------------------------------------------------------------------===// // Utility functions for propagating static information //===----------------------------------------------------------------------===// @@ -2969,7 +2955,7 @@ auto memrefType = memref.getType().cast(); unsigned rank = memrefType.getRank(); SmallVector offsets(rank, b.getIndexAttr(0)); - SmallVector sizes = getMixedSizes(b, loc, memref); + SmallVector sizes = reifyShapeDims(b, loc, memref); SmallVector 
strides(rank, b.getIndexAttr(1)); auto targetType = SubViewOp::inferRankReducedResultType( targetShape, memrefType, offsets, sizes, strides) diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp @@ -118,14 +118,9 @@ /// the tensor (for dynamic sizes). static void sizesForTensor(OpBuilder &builder, SmallVectorImpl &sizes, Location loc, ShapedType stp, Value tensor) { - for (const auto &d : enumerate(stp.getShape())) { - Value dim; - if (d.value() == ShapedType::kDynamic) - dim = builder.create(loc, tensor, d.index()); - else - dim = constantIndex(builder, loc, d.value()); - sizes.push_back(dim); - } + for (int64_t dim = 0; dim < stp.getRank(); ++dim) + sizes.push_back(getValueOrCreateConstantIndexOp( + builder, loc, reifyShapeDim(builder, loc, tensor, dim))); } // TODO: The dim level property of the COO type relies on input tensors, the diff --git a/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt b/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt --- a/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt @@ -66,6 +66,7 @@ MLIRLinalgDialect MLIRLinalgUtils MLIRSCFDialect + MLIRShapedTypeInterfaces MLIRSupport MLIRTensorDialect MLIRTensorUtils diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -22,6 +22,7 @@ #include "mlir/IR/OpDefinition.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Interfaces/DestinationStyleOpInterface.h" +#include "mlir/Interfaces/ShapedTypeInterfaces.h" #include "mlir/Support/MathExtras.h" #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/STLExtras.h" @@ -46,21 +47,6 @@ return nullptr; } -SmallVector tensor::getMixedSizes(OpBuilder 
&builder, - Location loc, Value value) { - auto tensorType = value.getType().cast(); - SmallVector result; - for (int64_t i = 0; i < tensorType.getRank(); ++i) { - if (tensorType.isDynamicDim(i)) { - Value size = builder.create(loc, value, i); - result.push_back(size); - } else { - result.push_back(builder.getIndexAttr(tensorType.getDimSize(i))); - } - } - return result; -} - FailureOr tensor::getOrCreateDestination(OpBuilder &b, Location loc, OpResult opResult) { auto tensorType = opResult.getType().dyn_cast(); @@ -2058,7 +2044,7 @@ auto rankedTensorType = tensor.getType().cast(); unsigned rank = rankedTensorType.getRank(); SmallVector offsets(rank, b.getIndexAttr(0)); - SmallVector sizes = getMixedSizes(b, loc, tensor); + SmallVector sizes = reifyShapeDims(b, loc, tensor); SmallVector strides(rank, b.getIndexAttr(1)); return b.createOrFold(loc, targetType, tensor, offsets, sizes, strides); @@ -2199,16 +2185,8 @@ LogicalResult InsertSliceOp::reifyResultShapes( OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) { - reifiedReturnShapes.resize(1, SmallVector(getType().getRank())); - for (auto dim : llvm::seq(0, getType().getRank())) { - if (getType().isDynamicDim(dim)) { - reifiedReturnShapes[0][dim] = - builder.createOrFold(getLoc(), getDest(), dim); - } else { - reifiedReturnShapes[0][dim] = - builder.getIndexAttr(getType().getDimSize(dim)); - } - } + // The result has the shape of `dest`, not of the inserted slice. + reifiedReturnShapes.push_back(reifyShapeDims(builder, getLoc(), getDest())); return success(); } @@ -2410,7 +2388,7 @@ auto rankedTensorType = dest.getType().cast(); unsigned rank = rankedTensorType.getRank(); SmallVector offsets(rank, b.getIndexAttr(0)); - SmallVector sizes = getMixedSizes(b, loc, dest); + SmallVector sizes = reifyShapeDims(b, loc, dest); SmallVector strides(rank, b.getIndexAttr(1)); return b.createOrFold(loc, tensor, dest, offsets, sizes, strides); @@ -3153,18 +3131,8 @@ ReifiedRankedShapedTypeDims &reifiedReturnShapes) { static_assert(llvm::is_one_of::value, "applies to only pack or unpack operations"); -
int64_t destRank = op.getDestRank(); - reifiedReturnShapes.resize(1, SmallVector(destRank)); - ShapedType resultType = op.getResult().getType().template cast(); - for (auto dim : llvm::seq(0, destRank)) { - if (resultType.isDynamicDim(dim)) { - reifiedReturnShapes[0][dim] = - builder.createOrFold(op.getLoc(), op.getDest(), dim); - } else { - reifiedReturnShapes[0][dim] = - builder.getIndexAttr(resultType.getDimSize(dim)); - } - } + reifiedReturnShapes.push_back( + reifyShapeDims(builder, op.getLoc(), op.getDest())); return success(); } diff --git a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp --- a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp @@ -16,6 +16,7 @@ #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/Utils/Utils.h" #include "mlir/Dialect/Utils/IndexingUtils.h" +#include "mlir/Interfaces/ShapedTypeInterfaces.h" #include "mlir/Interfaces/TilingInterface.h" using namespace mlir; @@ -135,7 +136,7 @@ DenseMap dimAndTileMapping = packOp.getDimAndTileMapping(); SmallVector srcDimValues = - tensor::createDimValues(b, loc, packOp.getSource()); + reifyShapeDims(b, loc, packOp.getSource()); SmallVector inputIndices, inputSizes; for (auto dim : llvm::seq(0, inputRank)) { using AV = AffineValueExpr; diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp @@ -19,6 +19,7 @@ #include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/Operation.h" +#include "mlir/Interfaces/ShapedTypeInterfaces.h" using namespace mlir; using namespace mlir::bufferization; @@ -917,7 +918,7 @@ // Create tensor::InsertSliceOp. 
SmallVector sliceSizes = - getMixedSizes(rewriter, loc, padOp.getSource()); + reifyShapeDims(rewriter, loc, padOp.getSource()); SmallVector sliceStrides(srcType.getRank(), rewriter.getIndexAttr(1)); rewriter.replaceOpWithNewOp( diff --git a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp --- a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp @@ -39,34 +39,6 @@ return b.create(loc, type, source, low, high, pad, nofold); } -SmallVector mlir::tensor::createDynamicDimValues(OpBuilder &b, - Location loc, - Value rankedTensor) { - auto tensorTy = rankedTensor.getType().cast(); - SmallVector dynamicDims; - for (const auto &en : llvm::enumerate(tensorTy.getShape())) { - if (en.value() == ShapedType::kDynamic) - dynamicDims.push_back( - b.create(loc, rankedTensor, en.index())); - } - return dynamicDims; -} - -SmallVector -mlir::tensor::createDimValues(OpBuilder &b, Location loc, Value rankedTensor) { - auto tensorTy = rankedTensor.getType().cast(); - SmallVector dims; - for (const auto &en : llvm::enumerate(tensorTy.getShape())) { - if (ShapedType::isDynamic(en.value())) { - dims.push_back( - b.createOrFold(loc, rankedTensor, en.index())); - } else { - dims.push_back(b.getIndexAttr(en.value())); - } - } - return dims; -} - FailureOr mlir::tensor::computeTransposedType(RankedTensorType rankedTensorType, ArrayRef transposeVector) { diff --git a/mlir/lib/Interfaces/ShapedTypeInterfaces.cpp b/mlir/lib/Interfaces/ShapedTypeInterfaces.cpp --- a/mlir/lib/Interfaces/ShapedTypeInterfaces.cpp +++ b/mlir/lib/Interfaces/ShapedTypeInterfaces.cpp @@ -25,3 +25,25 @@ Value shapedValue, int64_t dim) { return reifyShapeDim(builder, loc, shapedValue, builder.getIndexAttr(dim)); } + +SmallVector mlir::reifyShapeDims(OpBuilder &builder, Location loc, + Value shapedValue) { + auto shapedType = shapedValue.getType().cast(); + assert(shapedType.hasRank() && "expected ranked shaped value"); + return 
llvm::to_vector(llvm::map_range( + llvm::seq(0, shapedType.getRank()), [&](int64_t dim) { + return reifyShapeDim(builder, loc, shapedValue, dim); + })); +} + +SmallVector mlir::reifyDynamicShapeDims(OpBuilder &builder, Location loc, + Value shapedValue) { + SmallVector result; + auto shapedType = shapedValue.getType().cast(); + assert(shapedType.hasRank() && "expected ranked shaped value"); + for (const auto &dim : llvm::enumerate(shapedType.getShape())) + if (dim.value() == ShapedType::kDynamic) + result.push_back( + reifyShapeDim(builder, loc, shapedValue, dim.index()).get()); + return result; +} diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -5535,6 +5535,7 @@ ":LinalgDialect", ":LinalgUtils", ":SCFDialect", + ":ShapedTypeInterfaces", ":TensorDialect", ":TensorUtils", ":TilingInterface", @@ -8443,6 +8444,7 @@ ":PDLDialect", ":Parser", ":SCFTransforms", + ":ShapedTypeInterfaces", ":SideEffectInterfaces", ":Support", ":TensorDialect", @@ -10391,6 +10393,7 @@ ":IR", ":InferTypeOpInterface", ":MemRefDialect", + ":ShapedTypeInterfaces", ":SparseTensorDialect", ":Support", ":TensorDialect",