[mlir] Share one getDynOperands helper across Linalg and Transforms

Hoist the `getDynOperands` helper out of Linalg's Bufferize.cpp into
mlir/Dialect/Linalg/Utils so that Bufferize, ElementwiseToLinalg,
BufferDeallocation and PipelineDataTransfer all use the same code to create
the DimOps for the dynamic dimensions of a shaped value.

This also fixes PipelineDataTransfer, which previously passed the running
count of dynamic dimensions (`dynamicDimCount++`) to DimOp instead of the
dimension index, and so queried the wrong dimension for shapes such as
`memref<4x?xf32>`.

Notes for review:
* The helper is declared only in Linalg/Utils/Utils.h. A previous revision
  also declared an undefined `mlir::getDynOperands` in mlir/Transforms/Utils.h.
  Under `using namespace ::mlir; using namespace ::mlir::linalg;` (as in
  Bufferize.cpp) the two declarations would make unqualified calls ambiguous
  wherever both headers are visible.
* The helper tests `ShapedType::isDynamic` rather than comparing against
  `TensorType::kDynamicSize`, since it is also applied to memrefs.
* TODO(review): lib/Transforms now includes a Linalg utils header. Confirm the
  MLIRTransforms CMake target links the library that defines
  `linalg::getDynOperands` and that no dependency cycle is introduced.
* TODO(review): ElementwiseToLinalg now reads the dynamic dimensions from the
  first operand's type instead of the result type; confirm the two always
  agree on which dimensions are dynamic.

diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -222,6 +222,11 @@
        Optional<LinalgLoopDistributionOptions> = None);
 };
 
+/// Returns the values of the dynamic dimensions of `val`, which must be of
+/// ShapedType, by creating one DimOp per dynamic dimension at `loc` using `b`.
+/// Results are in dimension order; static dimensions are skipped.
+SmallVector<Value, 4> getDynOperands(Location loc, Value val, OpBuilder &b);
+
 } // namespace linalg
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
@@ -21,18 +21,6 @@
 using namespace ::mlir;
 using namespace ::mlir::linalg;
 
-static SmallVector<Value, 4> getDynOperands(Location loc, Value val,
-                                            OpBuilder &b) {
-  SmallVector<Value, 4> dynOperands;
-  auto shapedType = val.getType().cast<ShapedType>();
-  for (auto dim : llvm::enumerate(shapedType.getShape())) {
-    if (dim.value() == TensorType::kDynamicSize) {
-      dynOperands.push_back(b.create<DimOp>(loc, val, dim.index()));
-    }
-  }
-  return dynOperands;
-}
-
 static Value cloneMemref(Location loc, Value memref, OpBuilder &b) {
   auto memrefType = memref.getType().cast<MemRefType>();
   auto alloc =
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
@@ -10,6 +10,7 @@
 
 #include "PassDetail.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Transforms/DialectConversion.h"
 
@@ -62,17 +63,8 @@
     // Extract static / dynamic shape mix from the first operand.
     Value firstOperand = operands.front();
     auto rankedTensorType = t.cast<RankedTensorType>();
-    SmallVector<Value, 8> dynamicShape;
-    SmallVector<int64_t, 8> staticShape;
-    dynamicShape.reserve(rankedTensorType.getRank());
-    staticShape.reserve(rankedTensorType.getRank());
-    unsigned idx = 0;
-    for (auto shape : rankedTensorType.getShape()) {
-      staticShape.push_back(shape);
-      if (rankedTensorType.isDynamicDim(idx))
-        dynamicShape.push_back(b.create<DimOp>(loc, firstOperand, idx));
-      ++idx;
-    }
+    auto staticShape = llvm::to_vector<4>(rankedTensorType.getShape());
+    auto dynamicShape = linalg::getDynOperands(loc, firstOperand, b);
     // Create init tensor.
     res.push_back(b.create<linalg::InitTensorOp>(
         loc, dynamicShape, staticShape, rankedTensorType.getElementType()));
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -366,5 +366,15 @@
   assert(ivs.size() == iteratorTypes.size() && "did not generate enough loops");
 }
 
+SmallVector<Value, 4> getDynOperands(Location loc, Value val, OpBuilder &b) {
+  SmallVector<Value, 4> dynOperands;
+  auto shapedType = val.getType().cast<ShapedType>();
+  for (auto dim : llvm::enumerate(shapedType.getShape())) {
+    if (ShapedType::isDynamic(dim.value()))
+      dynOperands.push_back(b.create<DimOp>(loc, val, dim.index()));
+  }
+  return dynOperands;
+}
+
 } // namespace linalg
 } // namespace mlir
diff --git a/mlir/lib/Transforms/BufferDeallocation.cpp b/mlir/lib/Transforms/BufferDeallocation.cpp
--- a/mlir/lib/Transforms/BufferDeallocation.cpp
+++ b/mlir/lib/Transforms/BufferDeallocation.cpp
@@ -53,6 +53,7 @@
 
 #include "PassDetail.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
@@ -394,13 +395,8 @@
 
     // Extract information about dynamically shaped types by
     // extracting their dynamic dimensions.
-    SmallVector<Value, 4> dynamicOperands;
-    for (auto shapeElement : llvm::enumerate(memRefType.getShape())) {
-      if (!ShapedType::isDynamic(shapeElement.value()))
-        continue;
-      dynamicOperands.push_back(builder.create<DimOp>(
-          terminator->getLoc(), sourceValue, shapeElement.index()));
-    }
+    auto dynamicOperands =
+        linalg::getDynOperands(terminator->getLoc(), sourceValue, builder);
 
     // TODO: provide a generic interface to create dialect-specific
     // Alloc and CopyOp nodes.
diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp
--- a/mlir/lib/Transforms/PipelineDataTransfer.cpp
+++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Analysis/LoopAnalysis.h"
 #include "mlir/Analysis/Utils.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/Transforms/LoopUtils.h"
 #include "mlir/Transforms/Utils.h"
@@ -83,13 +84,8 @@
   // The double buffer is allocated right before 'forOp'.
   OpBuilder bOuter(forOp);
   // Put together alloc operands for any dynamic dimensions of the memref.
-  SmallVector<Value, 4> allocOperands;
-  unsigned dynamicDimCount = 0;
-  for (auto dimSize : oldMemRefType.getShape()) {
-    if (dimSize == -1)
-      allocOperands.push_back(
-          bOuter.create<DimOp>(forOp.getLoc(), oldMemRef, dynamicDimCount++));
-  }
+  auto allocOperands =
+      linalg::getDynOperands(forOp.getLoc(), oldMemRef, bOuter);
 
   // Create and place the alloc right before the 'affine.for' operation.
   Value newMemRef =