diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td @@ -1073,18 +1073,20 @@ >, InterfaceMethod< /*desc=*/[{ - Return the position in the results of the affine map computed - by getLoopsToShapesMap() that represents the shape of the - result value at a dimension. + Return the half-open range [first, second) of positions in the results + of the affine map computed by getLoopsToShapesMap() that correspond to + the AffineExprs used to access the outputs of the operation. }], - /*retTy=*/"Optional", - /*methodName=*/"getResultValueDimPositionInLoopsToShapeMap", - /*args=*/(ins "unsigned":$resultIdx, "unsigned":$dim), + /*retTy=*/"std::pair", + /*methodName=*/"getResultsPositionInLoopsToShapeMap", + /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - if (resultIdx >= getNumOutputs()) return {}; - return getOperandDimPositionInLoopsToShapeMap( - getNumInputs() + resultIdx, dim); + unsigned numResultDims = 0; + for (unsigned i = 0, e = getNumOutputs(); i < e; ++i) + numResultDims += getOutputShapedType(i).getRank(); + unsigned begin = *getOperandDimPositionInLoopsToShapeMap(getNumInputs(), 0); + return {begin, begin + numResultDims}; }] >, @@ -1153,8 +1155,8 @@ /// Returns the value that expresses the shape of the output in terms of /// shape of the input operands where possible - Value reifyReturnTypeShapeForResult(OpBuilder &b, unsigned resultIdx, - int64_t dim); + LogicalResult reifyReturnTypeShapes(OpBuilder &b, + SmallVectorImpl> &reifiedReturnShapes); //========================================================================// // Helper functions to mutate the `operand_segment_sizes` attribute.
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td @@ -37,7 +37,7 @@ def Linalg_InitTensorOp : Linalg_Op<"init_tensor", [NoSideEffect, DeclareOpInterfaceMethods]> { + ["reifyReturnTypeShapes"]>]> { let summary = "operation to define a tensor of particular value"; let description = [{ @@ -132,7 +132,7 @@ def Linalg_PadTensorOp : Linalg_Op<"pad_tensor", [AttrSizedOperandSegments, DeclareOpInterfaceMethods, + ["reifyReturnTypeShapes"]>, NoSideEffect]> { let summary = "tensor pad operation"; let description = [{ @@ -422,7 +422,7 @@ def Linalg_TensorReshapeOp : Linalg_ReshapeLikeOp< "tensor_reshape", [DeclareOpInterfaceMethods]>, + ["reifyReturnTypeShapes"]>]>, Arguments<(ins AnyTensor:$src, AffineMapArrayAttr:$reassociation)>, Results<(outs AnyTensor:$result)> { diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -35,10 +35,10 @@ return isa(this->getOperation()) ? getNumLoops() : 0; } - Value reifyReturnTypeShapeForResult(OpBuilder &b, unsigned resultIdx, - int64_t dim) { - return cast(getOperation()). 
- reifyReturnTypeShapeForResult(b, resultIdx, dim); + LogicalResult reifyReturnTypeShapes(OpBuilder &b, + SmallVectorImpl> &reifiedReturnShapes) { + return cast(getOperation()).reifyReturnTypeShapes(b, + reifiedReturnShapes); } }]; } diff --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td --- a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td +++ b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td @@ -100,79 +100,26 @@ "::mlir::SmallVectorImpl<::mlir::ShapedTypeComponents>&": $inferredReturnShapes), /*methodBody=*/[{}], - /*defaultImplementation=*/[{ return failure(); }] - >, - InterfaceMethod< - /*desc=*/[{Reify the shape computation for the operation. - - Insert operations using the given OpBuilder that computes the result - shape along a particular dimension of result at `resultIndex`. - }], - /*retTy=*/"::mlir::Value", - /*methodName=*/"reifyReturnTypeShapeForResult", - /*args=*/(ins "::mlir::OpBuilder&":$builder, - "unsigned":$resultIndex, - "int64_t":$dim), - /*methodBody=*/[{}], - /*defaultImplementation=*/[{ - if (resultIndex == 0) return $_op.reifyReturnTypeShape(builder, dim); - return nullptr; }] + /*defaultImplementation=*/[{ return ::mlir::failure(); }] >, InterfaceMethod< /*desc=*/[{Reify the shape computation for the operation. Insert operations using the given OpBuilder that computes the - result shape along a particular dimension. Valid only when the - operation has a single result value. - }], - /*retTy=*/"::mlir::Value", - /*methodName=*/"reifyReturnTypeShape", - /*args=*/(ins "::mlir::OpBuilder&":$builder, - "int64_t":$dim), - /*methodBody=*/[{}], - /*defaultImplementation=*/[{ return nullptr; }] - >, - InterfaceMethod< - /*desc=*/[{Reify the shape computation for the operation. - - Insert operations using the given OpBuilder that computes the result - shape for a result at `resultIndex`. 
- }], - /*retTy=*/"::mlir::LogicalResult", - /*methodName=*/"reifyReturnTypeShapesForResult", - /*args=*/(ins "::mlir::OpBuilder&":$builder, - "unsigned":$resultIndex, - "::mlir::SmallVectorImpl<::mlir::Value>&":$reifiedReturnShapes), - /*methodBody=*/[{}], - /*defaultImplementation=*/[{ - if (resultIndex == 0) - return $_op.reifyReturnTypeShapes(builder, reifiedReturnShapes); - return ::mlir::failure(); - }] - >, - InterfaceMethod< - /*desc=*/[{Reify the shape computation for the operation. - - Insert operations using the given OpBuilder that computes the result - shape. Valid only when the operation has a single result value. + result shape. The `reifiedReturnShapes` is expected to be + populated with as many vectors as the number of results of the + op (empty if the shape of a result value cannot be computed). If + the returned shape for a result is not empty, its size must + match the rank of the shaped type returned. }], /*retTy=*/"::mlir::LogicalResult", /*methodName=*/"reifyReturnTypeShapes", /*args=*/(ins "::mlir::OpBuilder&":$builder, - "::mlir::SmallVectorImpl<::mlir::Value>&":$reifiedReturnShapes), + "::mlir::SmallVectorImpl>&" + :$reifiedReturnShapes), /*methodBody=*/[{}], - /*defaultImplementation=*/[{ - ShapedType type = - $_op.getOperation()->getResult(0).getType().template dyn_cast(); - if (!type) return failure(); - for (auto dim : llvm::seq(0, type.getRank())) { - Value shape = $_op.reifyReturnTypeShape(builder, dim); - if (!shape) return failure(); - reifiedReturnShapes.push_back(shape); - } - return success(); - }] - >, + /*defaultImplementation=*/[{ return ::mlir::failure(); }] + > ]; } diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp @@ -233,56 +233,57 @@ llvm::SmallSet positions; }; -Value LinalgOp::reifyReturnTypeShapeForResult(OpBuilder &b, unsigned resultIdx, - int64_t 
 dim) { +LogicalResult LinalgOp::reifyReturnTypeShapes( + OpBuilder &b, SmallVectorImpl> &reifiedReturnShapes) { // An example that helps understand the logic below. // Consider the following expression O(i+j, j) += A(i,k) * B(k, j) // We want to express the shape of dim 0 of O in terms of shape of the inputs. // This is achieved as follows. // loopsToShapesMap = (d0, d1, d2) -> (d0, d2, d2, d1, d0 + d1, d1) - // subMapOfResultDim = (d0, d1, d2) -> (d0 + d1) + // subMapOfResultShapes = (d0, d1, d2) -> (d0 + d1, d1) // shapesToLoopsMap = (d0, d2, d2, d3, d4, d5) -> (d0, d3, d2) - // resultFromFromInputDim = subMapOfResultDim.compose(shapesToLoopMap) - // = (d0, d1, d2, d3, d4, d5) -> (d0 + d1) + // resultShapesFromInputShapes = subMapOfResultShapes.compose(shapesToLoopsMap) + // = (d0, d1, d2, d3, d4, d5) -> (d0 + d1, d1) AffineMap loopsToShapesMap = getLoopsToShapesMap(); // Find the position in the above map that represents the shape of the // result:dim being inferred. - Optional resultDimSubMapPos = - getResultValueDimPositionInLoopsToShapeMap(resultIdx, dim); - if (!resultDimSubMapPos) - return {}; + auto resultShapesSubMapPos = getResultsPositionInLoopsToShapeMap(); /// From loopsToShapesMap extract the submap that represents the shape of the - /// (resultIdx, dim) needed - AffineMap loopToResultDimShapeMap = - loopsToShapesMap.getSubMap(*resultDimSubMapPos); - AffineMap operandShapesToResultDimMap = - loopToResultDimShapeMap.compose(getShapesToLoopsMap()); + /// results (every dimension of every output). + SmallVector resultPosRange = + llvm::to_vector<4>(llvm::seq(resultShapesSubMapPos.first, + resultShapesSubMapPos.second)); + AffineMap loopToResultsShapeMap = loopsToShapesMap.getSubMap(resultPosRange); + AffineMap resultShapesFromInputShapesMap = + loopToResultsShapeMap.compose(getShapesToLoopsMap()); // Check that the result dim map does not contain the positions corresponding // to the outputs. 
llvm::SmallSet outputDims; - unsigned outputDimPosStart = - getResultValueDimPositionInLoopsToShapeMap(0, 0).getValue(); - unsigned outputDimPosEnd = - getResultValueDimPositionInLoopsToShapeMap(getNumOutputs() - 1, - getOutputOpOperands() - .back() - .get() - .getType() - .cast() - .getRank() - - 1) - .getValue(); - llvm::for_each(llvm::seq(outputDimPosStart, outputDimPosEnd), + llvm::for_each(resultPosRange, [&outputDims](unsigned dim) { outputDims.insert(dim); }); HasAffineDimExprVisitor checkDimExpr(outputDims); Location loc = getOperation()->getLoc(); - if (checkDimExpr.visit(operandShapesToResultDimMap.getResult(0))) - return b.createOrFold(loc, getOutput(resultIdx), dim); - return applyMapToValues(b, loc, operandShapesToResultDimMap, - createFlatListOfOperandDims(b, loc))[0]; + auto allResultDimValues = + applyMapToValues(b, loc, resultShapesFromInputShapesMap, + createFlatListOfOperandDims(b, loc)); + unsigned pos = 0; + ArrayRef shapeExprs = resultShapesFromInputShapesMap.getResults(); + for (auto resultIdx : llvm::seq(0, getNumOutputs())) { + ShapedType resultType = getOutputShapedType(resultIdx); + SmallVector shapes; + for (unsigned dim : llvm::seq(0, resultType.getRank())) { + if (checkDimExpr.visit(shapeExprs[pos])) + shapes.push_back(b.createOrFold(loc, getOutput(resultIdx), dim)); + else + shapes.push_back(allResultDimValues[pos]); + pos++; + } + reifiedReturnShapes.emplace_back(std::move(shapes)); + } + return success(); } LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) { diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -804,12 +804,12 @@ if (!reshapeOp.src().getDefiningOp()) return failure(); Location loc = reshapeOp.getLoc(); - SmallVector resultShapeValues; - if (failed(reshapeOp.reifyReturnTypeShapesForResult(rewriter, 0, - resultShapeValues))) + SmallVector, 4> resultShapes; + 
if (failed(reshapeOp.reifyReturnTypeShapes(rewriter, resultShapes)) || + !llvm::hasSingleElement(resultShapes)) return failure(); Value initTensor = rewriter.create( - loc, resultShapeValues, reshapeOp.getResultType().getElementType()); + loc, resultShapes[0], reshapeOp.getResultType().getElementType()); rewriter.replaceOpWithNewOp( reshapeOp, reshapeOp.getResultType(), initTensor); return success(); @@ -824,10 +824,18 @@ context); } -Value InitTensorOp::reifyReturnTypeShape(OpBuilder &builder, int64_t dim) { - if (isDynamicSize(dim)) - return getDynamicSize(dim); - return builder.create(getLoc(), getStaticSize(dim)); +LogicalResult InitTensorOp::reifyReturnTypeShapes( + OpBuilder &builder, + SmallVectorImpl> &reifiedReturnShapes) { + Location loc = getLoc(); + auto shapes = llvm::to_vector<4>(llvm::map_range( + llvm::seq(0, getType().getRank()), [&](int64_t dim) -> Value { + if (isDynamicSize(dim)) + return getDynamicSize(dim); + return builder.create(loc, getStaticSize(dim)); + })); + reifiedReturnShapes.emplace_back(std::move(shapes)); + return success(); } //===----------------------------------------------------------------------===// @@ -979,20 +987,31 @@ builder); } -Value PadTensorOp::reifyReturnTypeShape(OpBuilder &b, int64_t dim) { +LogicalResult PadTensorOp::reifyReturnTypeShapes( + OpBuilder &b, SmallVectorImpl> &reifiedReturnShapes) { Location loc = getLoc(); - auto getAsValue = [&](OpFoldResult valueOrAttr) -> Value { - if (Attribute attr = valueOrAttr.dyn_cast()) - return b.create(loc, attr.cast().getInt()); - return valueOrAttr.get(); + auto getAsValues = [&](ArrayRef foldResults) { + return llvm::to_vector<4>( + llvm::map_range(foldResults, [&](OpFoldResult foldResult) -> Value { + if (Attribute attr = foldResult.dyn_cast()) + return b.create(loc, + attr.cast().getInt()); + return foldResult.get(); + })); }; - auto lowPad = getAsValue(getMixedLowPad()[dim]); - auto highPad = getAsValue(getMixedHighPad()[dim]); - Value sourceDim = b.create(loc, 
source(), dim); - AffineExpr expr = b.getAffineDimExpr(0) + b.getAffineSymbolExpr(0) + - b.getAffineSymbolExpr(1); - return applyMapToValues(b, loc, AffineMap::get(1, 2, expr), - {sourceDim, lowPad, highPad})[0]; + auto lowPad = getAsValues(getMixedLowPad()); + auto highPad = getAsValues(getMixedHighPad()); + auto shape = llvm::to_vector<4>(llvm::map_range( + llvm::seq(0, getSourceType().getRank()), + [&](int64_t dim) -> Value { + Value sourceDim = b.create(loc, source(), dim); + AffineExpr expr = b.getAffineDimExpr(0) + b.getAffineSymbolExpr(0) + + b.getAffineSymbolExpr(1); + return applyMapToValues(b, loc, AffineMap::get(1, 2, expr), + {sourceDim, lowPad[dim], highPad[dim]})[0]; + })); + reifiedReturnShapes.emplace_back(std::move(shape)); + return success(); } //===----------------------------------------------------------------------===// @@ -1642,17 +1661,11 @@ context); } -Value TensorReshapeOp::reifyReturnTypeShape(OpBuilder &b, int64_t dim) { - return getReshapeOutputDimFromInputShape(b, getLoc(), dim, src(), - getResultType().getShape(), - getReassociationMaps()); -} - LogicalResult TensorReshapeOp::reifyReturnTypeShapes( - OpBuilder &b, SmallVectorImpl &reifiedReturnShapes) { + OpBuilder &b, SmallVectorImpl> &reifiedReturnShapes) { auto resultShape = getReshapeOutputShapeFromInputShape( b, getLoc(), src(), getResultType().getShape(), getReassociationMaps()); - reifiedReturnShapes.append(resultShape.begin(), resultShape.end()); + reifiedReturnShapes.emplace_back(std::move(resultShape)); return success(); } diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -1576,12 +1576,16 @@ if (!dimIndex) return failure(); - Value dimReplacement = shapedTypeOp.reifyReturnTypeShapeForResult( - rewriter, dimValue.getResultNumber(), *dimIndex); - if (!dimReplacement) + SmallVector> returnShapes; + if 
(failed(shapedTypeOp.reifyReturnTypeShapes(rewriter, returnShapes))) return failure(); - rewriter.replaceOp(dimOp, dimReplacement); + unsigned resultNumber = dimValue.getResultNumber(); + if (returnShapes.size() <= resultNumber || + returnShapes[resultNumber].size() <= static_cast(*dimIndex)) + return failure(); + + rewriter.replaceOp(dimOp, returnShapes[resultNumber][*dimIndex]); return success(); } }; diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -678,9 +678,10 @@ } LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes( - OpBuilder &builder, llvm::SmallVectorImpl &shapes) { - shapes = SmallVector{ - builder.createOrFold(getLoc(), getOperand(0), 0)}; + OpBuilder &builder, + llvm::SmallVectorImpl> &shapes) { + shapes.push_back(SmallVector{ + builder.createOrFold(getLoc(), getOperand(0), 0)}); return success(); } diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -126,10 +126,11 @@ // Use permutations of 2 args as operands. auto shapedOp = cast(op); - SmallVector shapes; - if (failed(shapedOp.reifyReturnTypeShapes(b, shapes))) + SmallVector, 2> shapes; + if (failed(shapedOp.reifyReturnTypeShapes(b, shapes)) || + !llvm::hasSingleElement(shapes)) return; - for (auto it : llvm::enumerate(shapes)) + for (auto it : llvm::enumerate(shapes[0])) op->emitRemark() << "value " << it.index() << ": " << it.value().getDefiningOp(); }