diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h @@ -19,6 +19,7 @@ #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/OpDefinition.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/ViewLikeInterface.h" namespace mlir { diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td @@ -928,8 +928,8 @@ /// Returns the value that expresses the shape of the output in terms of /// shape of the input operands where possible - LogicalResult reifyReturnTypeShapesPerResultDim(OpBuilder &b, - SmallVectorImpl> &reifiedReturnShapes); + LogicalResult reifyResultShapes(OpBuilder &b, + ReifiedRankedShapedTypeDims &reifiedReturnShapes); //========================================================================// // Helper functions to mutate the `operand_segment_sizes` attribute. 
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td @@ -36,8 +36,7 @@ def Linalg_InitTensorOp : Linalg_Op<"init_tensor", [NoSideEffect, - DeclareOpInterfaceMethods]> { + DeclareOpInterfaceMethods]> { let summary = "operation to define a tensor of particular value"; let description = [{ @@ -130,10 +129,8 @@ } def Linalg_PadTensorOp : Linalg_Op<"pad_tensor", - [AttrSizedOperandSegments, - DeclareOpInterfaceMethods, - NoSideEffect]> { + [AttrSizedOperandSegments, NoSideEffect, + DeclareOpInterfaceMethods]> { let summary = "tensor pad operation"; let description = [{ `linalg.pad_tensor` is an operation that pads the `source` tensor @@ -398,8 +395,7 @@ class Linalg_TensorReshapeOp : Linalg_ReshapeLikeOp< mnemonic, - [DeclareOpInterfaceMethods]>, + [DeclareOpInterfaceMethods]>, Arguments<(ins AnyTensor:$src, IndexListArrayAttr:$reassociation)>, Results<(outs AnyTensor:$result)> { diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -26,7 +26,7 @@ // depending on the specific Linalg op. class LinalgStructuredBase_Op props> : Op { + LinalgStructuredInterface, ReifyRankedShapedTypeOpInterface])> { code structuredOpsBaseDecls = [{ // Return whether the op accesses the iteration indices. 
bool hasIndexSemantics() { @@ -36,9 +36,9 @@ return !op->getRegion(0).front().getOps().empty(); } - LogicalResult reifyReturnTypeShapesPerResultDim(OpBuilder &b, - SmallVectorImpl> &reifiedReturnShapes) { - return cast(getOperation()).reifyReturnTypeShapesPerResultDim(b, + LogicalResult reifyResultShapes(OpBuilder &b, + ReifiedRankedShapedTypeDims &reifiedReturnShapes) { + return cast(getOperation()).reifyResultShapes(b, reifiedReturnShapes); } }]; diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h --- a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h @@ -35,6 +35,13 @@ /// into `patterns`. void populateFoldSubViewOpPatterns(RewritePatternSet &patterns); +/// Appends patterns that resolve `memref.dim` operations with values that are +/// defined by operations that implement the +/// `ReifyRankedShapedTypeOpInterface`, in terms of shapes of its input +/// operands. +void populateResolveRankedShapeTypeResultDimsPatterns( + RewritePatternSet &patterns); + /// Appends patterns that resolve `memref.dim` operations with values that are /// defined by operations that implement the `InferShapedTypeOpInterface`, in /// terms of shapes of its input operands. @@ -50,7 +57,14 @@ /// Creates an operation pass to resolve `memref.dim` operations with values /// that are defined by operations that implement the -/// `InferShapedTypeOpInterface`, in terms of shapes of its input operands. +/// `ReifyRankedShapedTypeOpInterface`, in terms of shapes of its input +/// operands. +std::unique_ptr createResolveRankedShapeTypeResultDimsPass(); + +/// Creates an operation pass to resolve `memref.dim` operations with values +/// that are defined by operations that implement the +/// `InferShapedTypeOpInterface` or the `ReifyRankedShapedTypeOpInterface`, +/// in terms of shapes of its input operands. 
std::unique_ptr createResolveShapedTypeResultDimsPass(); //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td --- a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td @@ -23,12 +23,28 @@ ]; } +def ResolveRankedShapeTypeResultDims : + Pass<"resolve-ranked-shaped-type-result-dims"> { + let summary = "Resolve memref.dim of result values of ranked shape type"; + let description = [{ + The pass resolves memref.dim of result of operations that + implement the `ReifyRankedShapedTypeOpInterface` in terms of + shapes of its operands. + }]; + let constructor = + "mlir::memref::createResolveRankedShapeTypeResultDimsPass()"; + let dependentDialects = [ + "memref::MemRefDialect", "tensor::TensorDialect" + ]; +} + def ResolveShapedTypeResultDims : Pass<"resolve-shaped-type-result-dims"> { let summary = "Resolve memref.dim of result values"; let description = [{ The pass resolves memref.dim of result of operations that - implement the `InferShapedTypeOpInterface` in terms of shapes of - its operands. + implement the `InferShapedTypeOpInterface` or + `ReifyRankedShapedTypeOpInterface` in terms of shapes of its + operands. 
}]; let constructor = "mlir::memref::createResolveShapedTypeResultDimsPass()"; let dependentDialects = [ diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td --- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td +++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td @@ -432,8 +432,7 @@ def Tensor_InsertSliceOp : BaseOpWithOffsetSizesAndStrides< Tensor_Dialect, "insert_slice", [NoSideEffect, AttrSizedOperandSegments, OffsetSizeAndStrideOpInterface, - DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, TypesMatchWith<"expected result type to match dest type", "dest", "result", "$_self">]> { let summary = "insert_slice operation"; diff --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.h b/mlir/include/mlir/Interfaces/InferTypeOpInterface.h --- a/mlir/include/mlir/Interfaces/InferTypeOpInterface.h +++ b/mlir/include/mlir/Interfaces/InferTypeOpInterface.h @@ -23,6 +23,8 @@ namespace mlir { +using ReifiedRankedShapedTypeDims = SmallVector>; + /// ShapedTypeComponents that represents the components of a ShapedType. /// The components consist of /// - A ranked or unranked shape with the dimension specification match those diff --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td --- a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td +++ b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td @@ -105,9 +105,7 @@ /*desc=*/[{Reify the shape computation for the operation. Insert operations using the given OpBuilder that computes the - result shape. Only one of this method or - `reifyReturnTypeShapesPerResultDim` needs to be overriden by the - operation. This interface is supposed to be workable during dialect + result shape. This interface is supposed to be workable during dialect conversion (e.g. convert from tensor world to buffer world), where `getOperand` may be invalid. For example, some ops (e.g. 
dynamic_reshape(input, target_shape)) may depend on their operands @@ -127,34 +125,6 @@ "::mlir::SmallVectorImpl<::mlir::Value> &":$reifiedReturnShapes), /*methodBody=*/[{}], /*defaultImplementation=*/[{ return ::mlir::failure(); }] - >, - InterfaceMethod< - /*desc=*/[{Reify the shape computation for the operation. - - Insert operations using the given OpBuilder that computes the - result shape. The `reifiedReturnShapes` is expected to be - populated with as many vectors as the number of results of the - op (empty if the shape of a result value cannot be computed). If - the returned shape for a result is not empty, its size must - match the rank of the shaped type returned. Consequently, this - interface can only be overridden if the return types are ranked. - - If both this method and `reifyReturnTypeShapes` are overridden - by the operation, `reifyReturnTypeShapes` takes precedence. This - method is intended to be used when the shape of each result, dim - pair can be computed independently. Using this method avoids - adding additional instructions to aggregate individual dimension - of a result shape into an single `Value` (and consequently - avoids the need to extract the value from the shape on the - client side). - }], - /*retTy=*/"::mlir::LogicalResult", - /*methodName=*/"reifyReturnTypeShapesPerResultDim", - /*args=*/(ins "::mlir::OpBuilder&":$builder, - "::mlir::SmallVectorImpl<::mlir::SmallVector<::mlir::Value>>&" - :$reifiedReturnShapes), - /*methodBody=*/[{}], - /*defaultImplementation=*/[{ return ::mlir::failure(); }] > ]; } @@ -176,4 +146,35 @@ defvar InferTensorTypeWithReify = InferTensorType<[ "inferReturnTypeComponents", "reifyReturnTypeShapes"]>; + +def ReifyRankedShapedTypeOpInterface : + OpInterface<"ReifyRankedShapedTypeOpInterface"> { + let description = [{ + Interface to compute the shape of the result of an operation when + the result is a ranked shape type, i.e. `RankedTensorType` or + `MemRefType`. 
+ }]; + let cppNamespace = "::mlir"; + + let methods = [ + InterfaceMethod< + /*desc=*/[{ + Reify the shape of the result of an operation (typically in + terms of shape of its operands) + + Insert operations using the given `OpBuilder` that computes + the result shape. The `reifiedReturnShapes` is expected to be + populated with as many vectors as the number of results of the + op. Each of these vectors is expected to be of size equal to + rank of the corresponding result. If the shape of a particular + result cannot be computed it must be empty. + }], + /*retTy=*/"LogicalResult", + /*methodName=*/"reifyResultShapes", + /*args=*/(ins "::mlir::OpBuilder &":$builder, + "ReifiedRankedShapedTypeDims &":$reifiedReturnShapes) + > + ]; +} + #endif // MLIR_INFERTYPEOPINTERFACE diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp @@ -274,8 +274,9 @@ llvm::SmallSet positions; }; -LogicalResult LinalgOp::reifyReturnTypeShapesPerResultDim( - OpBuilder &b, SmallVectorImpl> &reifiedReturnShapes) { +LogicalResult +LinalgOp::reifyResultShapes(OpBuilder &b, + ReifiedRankedShapedTypeDims &reifiedReturnShapes) { // An example that helps understand the logic below. // Consider the following expression O(i+j, j) += A(i,k) * B(k, j) // We want to express the shape of dim 0 of O in terms of shape of the inputs. 
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -779,9 +779,8 @@ if (!reshapeOp.src().template getDefiningOp()) return failure(); Location loc = reshapeOp.getLoc(); - SmallVector, 4> resultShapes; - if (failed(reshapeOp.reifyReturnTypeShapesPerResultDim(rewriter, - resultShapes)) || + ReifiedRankedShapedTypeDims resultShapes; + if (failed(reshapeOp.reifyResultShapes(rewriter, resultShapes)) || !llvm::hasSingleElement(resultShapes)) return failure(); Value initTensor = rewriter.create( @@ -825,9 +824,8 @@ ReplaceStaticShapeDims>(context); } -LogicalResult InitTensorOp::reifyReturnTypeShapesPerResultDim( - OpBuilder &builder, - SmallVectorImpl> &reifiedReturnShapes) { +LogicalResult InitTensorOp::reifyResultShapes( + OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) { auto shapes = llvm::to_vector<4>(llvm::map_range( llvm::seq(0, getType().getRank()), [&](int64_t dim) -> Value { if (isDynamicSize(dim)) @@ -1003,8 +1001,8 @@ builder); } -LogicalResult PadTensorOp::reifyReturnTypeShapesPerResultDim( - OpBuilder &b, SmallVectorImpl> &reifiedReturnShapes) { +LogicalResult PadTensorOp::reifyResultShapes( + OpBuilder &b, ReifiedRankedShapedTypeDims &reifiedReturnShapes) { Location loc = getLoc(); auto lowPad = getMixedLowPad(); auto highPad = getMixedHighPad(); @@ -1429,8 +1427,8 @@ FoldReshapeWithConstant>(context); } -LogicalResult TensorExpandShapeOp::reifyReturnTypeShapesPerResultDim( - OpBuilder &b, SmallVectorImpl> &reifiedReturnShapes) { +LogicalResult TensorExpandShapeOp::reifyResultShapes( + OpBuilder &b, ReifiedRankedShapedTypeDims &reifiedReturnShapes) { auto resultShape = getAsValues(b, getLoc(), getReshapeOutputShapeFromInputShape( @@ -1440,8 +1438,8 @@ return success(); } -LogicalResult TensorCollapseShapeOp::reifyReturnTypeShapesPerResultDim( - OpBuilder &b, SmallVectorImpl> 
&reifiedReturnShapes) { +LogicalResult TensorCollapseShapeOp::reifyResultShapes( + OpBuilder &b, ReifiedRankedShapedTypeDims &reifiedReturnShapes) { auto resultShape = getAsValues(b, getLoc(), getReshapeOutputShapeFromInputShape( diff --git a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp --- a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp @@ -1,5 +1,4 @@ -//===- ResolveShapedTypeResultDims.cpp - Resolve memref.dim ops of result values -//-------===// +//===- ResolveShapedTypeResultDims.cpp - Resolve dim ops of result values -===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -21,52 +20,6 @@ using namespace mlir; -/// Helper method to get the `Value` that is the shape of the `resultIdx`-th -/// result at dimension `dimIndex` from the `ShapedTypeOpInterface`. -/// TODO(ravishankarm): This is better put as a interface utility method -/// somewhere, but that would imply the interface will depend on the `tensor` -/// dialect. Ideally maybe a utility method in the `tensor` dialect. -static Value getResultDimFromShapeInterface(OpBuilder &builder, OpResult result, - int64_t dimIndex) { - unsigned resultNumber = result.getResultNumber(); - auto shapedTypeOp = dyn_cast(result.getOwner()); - Location loc = result.getOwner()->getLoc(); - if (!shapedTypeOp) - return nullptr; - - // The interface exposes two methods, one that returns the shape of all the - // results as `Value` and other that returns the shape as a list of - // `SmallVector`. The former takes precedence over the latter. So first - // check if the op implements the first interface method or the second, and - // get the value to use appropriately. 
- SmallVector reifiedResultShapes; - if (succeeded(shapedTypeOp.reifyReturnTypeShapes( - builder, result.getOwner()->getOperands(), reifiedResultShapes))) { - if (reifiedResultShapes.size() <= resultNumber) - return nullptr; - Value resultShape = reifiedResultShapes[resultNumber]; - auto resultShapeType = resultShape.getType().dyn_cast(); - if (!resultShapeType || !resultShapeType.getElementType().isa()) - return nullptr; - return builder.create( - loc, resultShape, builder.createOrFold(loc, dimIndex)); - } - - SmallVector> reifiedResultShapesPerDim; - if (failed(shapedTypeOp.reifyReturnTypeShapesPerResultDim( - builder, reifiedResultShapesPerDim))) - return nullptr; - if (reifiedResultShapesPerDim.size() <= resultNumber || - reifiedResultShapesPerDim[resultNumber].size() != - static_cast(result.getType().cast().getRank())) - return nullptr; - OpFoldResult valueOrAttr = reifiedResultShapesPerDim[resultNumber][dimIndex]; - if (auto attr = valueOrAttr.dyn_cast()) - return builder.createOrFold( - loc, attr.cast().getInt()); - return valueOrAttr.get(); -} - namespace { /// Fold dim of an operation that implements the InferShapedTypeOpInterface template @@ -86,11 +39,62 @@ Optional dimIndex = dimOp.getConstantIndex(); if (!dimIndex) return failure(); - Value replacement = - getResultDimFromShapeInterface(rewriter, dimValue, *dimIndex); - if (!replacement) + + SmallVector reifiedResultShapes; + if (failed(shapedTypeOp.reifyReturnTypeShapes( + rewriter, shapedTypeOp->getOperands(), reifiedResultShapes))) + return failure(); + + if (reifiedResultShapes.size() != shapedTypeOp->getNumResults()) return failure(); - rewriter.replaceOp(dimOp, replacement); + + Value resultShape = reifiedResultShapes[dimValue.getResultNumber()]; + auto resultShapeType = resultShape.getType().dyn_cast(); + if (!resultShapeType || !resultShapeType.getElementType().isa()) + return failure(); + + Location loc = dimOp->getLoc(); + rewriter.replaceOpWithNewOp( + dimOp, resultShape, + 
rewriter.createOrFold(loc, *dimIndex)); + return success(); + } +}; + +/// Fold dim of an operation that implements the ReifyRankedShapedTypeOpInterface +template +struct DimOfReifyRankedShapedTypeOpInterface : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(OpTy dimOp, + PatternRewriter &rewriter) const override { + OpResult dimValue = dimOp.source().template dyn_cast(); + if (!dimValue) + return failure(); + auto rankedShapeTypeOp = + dyn_cast(dimValue.getOwner()); + if (!rankedShapeTypeOp) + return failure(); + + Optional dimIndex = dimOp.getConstantIndex(); + if (!dimIndex) + return failure(); + + SmallVector> reifiedResultShapes; + if (failed( + rankedShapeTypeOp.reifyResultShapes(rewriter, reifiedResultShapes))) + return failure(); + + if (reifiedResultShapes.size() != rankedShapeTypeOp->getNumResults()) + return failure(); + + unsigned resultNumber = dimValue.getResultNumber(); + auto sourceType = dimValue.getType().dyn_cast(); + if (reifiedResultShapes[resultNumber].size() != + static_cast(sourceType.getRank())) + return failure(); + + rewriter.replaceOp(dimOp, reifiedResultShapes[resultNumber][*dimIndex]); return success(); } }; @@ -104,12 +108,26 @@ #define GEN_PASS_CLASSES #include "mlir/Dialect/MemRef/Transforms/Passes.h.inc" +struct ResolveRankedShapeTypeResultDimsPass final + : public ResolveRankedShapeTypeResultDimsBase< + ResolveRankedShapeTypeResultDimsPass> { + void runOnOperation() override; +}; + struct ResolveShapedTypeResultDimsPass final : public ResolveShapedTypeResultDimsBase { void runOnOperation() override; }; + } // namespace +void memref::populateResolveRankedShapeTypeResultDimsPatterns( + RewritePatternSet &patterns) { + patterns.add, + DimOfReifyRankedShapedTypeOpInterface>( + patterns.getContext()); +} + void memref::populateResolveShapedTypeResultDimsPatterns( RewritePatternSet &patterns) { // TODO: Move tensor::DimOp pattern to the Tensor dialect. 
@@ -118,8 +136,17 @@ patterns.getContext()); } +void ResolveRankedShapeTypeResultDimsPass::runOnOperation() { + RewritePatternSet patterns(&getContext()); + memref::populateResolveRankedShapeTypeResultDimsPatterns(patterns); + if (failed(applyPatternsAndFoldGreedily(getOperation()->getRegions(), + std::move(patterns)))) + return signalPassFailure(); +} + void ResolveShapedTypeResultDimsPass::runOnOperation() { RewritePatternSet patterns(&getContext()); + memref::populateResolveRankedShapeTypeResultDimsPatterns(patterns); memref::populateResolveShapedTypeResultDimsPatterns(patterns); if (failed(applyPatternsAndFoldGreedily(getOperation()->getRegions(), std::move(patterns)))) @@ -129,3 +156,7 @@ std::unique_ptr memref::createResolveShapedTypeResultDimsPass() { return std::make_unique(); } + +std::unique_ptr memref::createResolveRankedShapeTypeResultDimsPass() { + return std::make_unique(); +} diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -1042,9 +1042,8 @@ return OpFoldResult(); } -LogicalResult InsertSliceOp::reifyReturnTypeShapesPerResultDim( - OpBuilder &builder, - SmallVectorImpl> &reifiedReturnShapes) { +LogicalResult InsertSliceOp::reifyResultShapes( + OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) { reifiedReturnShapes.resize(1, SmallVector(getType().getRank())); for (auto dim : llvm::seq(0, getType().getRank())) { reifiedReturnShapes[0][dim] = diff --git a/mlir/test/Interfaces/InferShapedTypeOpInterface/resolve-shaped-type-result-dims.mlir b/mlir/test/Interfaces/InferShapedTypeOpInterface/resolve-shaped-type-result-dims.mlir --- a/mlir/test/Interfaces/InferShapedTypeOpInterface/resolve-shaped-type-result-dims.mlir +++ b/mlir/test/Interfaces/InferShapedTypeOpInterface/resolve-shaped-type-result-dims.mlir @@ -55,34 +55,3 @@ // CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG_1]], %[[C0]] // 
CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG_0]], %[[C2]] // CHECK: return %[[D0]], %[[C5]], %[[C2]], %[[C3]], %[[D1]] - -// ----- - -func @result_shape_and_per_dim(%arg0 : tensor<2x3x?xf32>, %arg1 : tensor) - -> (index, index, index, index, index) { - %c0 = constant 0 : index - %c1 = constant 1 : index - %c2 = constant 2 : index - %0:2 = "test.op_with_result_shape_and_per_dim_interface"(%arg0, %arg1) - : (tensor<2x3x?xf32>, tensor) -> (tensor, tensor<2x3x?xf32>) - %1 = tensor.dim %0#0, %c0 : tensor - %2 = tensor.dim %0#0, %c1 : tensor - %3 = tensor.dim %0#1, %c0 : tensor<2x3x?xf32> - %4 = tensor.dim %0#1, %c1 : tensor<2x3x?xf32> - %5 = tensor.dim %0#1, %c2 : tensor<2x3x?xf32> - return %1, %2, %3, %4, %5 : index, index, index, index, index -} -// CHECK-LABEL: func @result_shape_and_per_dim( -// CHECK-SAME: %[[ARG_0:[a-z0-9]*]]: tensor<2x3x?xf32> -// CHECK-SAME: %[[ARG_1:[a-z0-9]*]]: tensor) -// CHECK-DAG: %[[C0:.+]] = constant 0 : index -// CHECK-DAG: %[[C2:.+]] = constant 2 : index -// CHECK-DAG: %[[C3:.+]] = constant 3 : index -// CHECK-DAG: %[[C5:.+]] = constant 5 : index -// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG_1]], %[[C0]] -// CHECK-DAG: %[[S0:.+]] = tensor.from_elements %[[D0]], %[[C5]] -// CHECK-DAG: %[[D0_OUT:.+]] = tensor.extract %[[S0]][%[[C0]]] -// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG_0]], %[[C2]] -// CHECK-DAG: %[[S1:.+]] = tensor.from_elements %[[C2]], %[[C3]], %[[D1]] -// CHECK-DAG: %[[D1_OUT:.+]] = tensor.extract %[[S1]][%[[C2]]] -// CHECK: return %[[D0_OUT]], %[[C5]], %[[C2]], %[[C3]], %[[D1_OUT]] diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -822,46 +822,8 @@ return success(); } -LogicalResult -OpWithResultShapePerDimInterfaceOp ::reifyReturnTypeShapesPerResultDim( - OpBuilder &builder, - llvm::SmallVectorImpl> &shapes) { - Location loc = getLoc(); - shapes.reserve(getNumOperands()); - 
for (Value operand : llvm::reverse(getOperands())) { - auto currShape = llvm::to_vector<4>(llvm::map_range( - llvm::seq( - 0, operand.getType().cast().getRank()), - [&](int64_t dim) -> Value { - return builder.createOrFold(loc, operand, dim); - })); - shapes.emplace_back(std::move(currShape)); - } - return success(); -} - -LogicalResult OpWithResultShapeAndPerDimInterfaceOp::reifyReturnTypeShapes( - OpBuilder &builder, ValueRange operands, - llvm::SmallVectorImpl &shapes) { - Location loc = getLoc(); - shapes.reserve(operands.size()); - for (Value operand : llvm::reverse(operands)) { - auto currShape = llvm::to_vector<4>(llvm::map_range( - llvm::seq( - 0, operand.getType().cast().getRank()), - [&](int64_t dim) -> Value { - return builder.createOrFold(loc, operand, dim); - })); - shapes.push_back(builder.create( - getLoc(), builder.getIndexType(), currShape)); - } - return success(); -} - -LogicalResult -OpWithResultShapeAndPerDimInterfaceOp ::reifyReturnTypeShapesPerResultDim( - OpBuilder &builder, - llvm::SmallVectorImpl> &shapes) { +LogicalResult OpWithResultShapePerDimInterfaceOp::reifyResultShapes( + OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) { Location loc = getLoc(); shapes.reserve(getNumOperands()); for (Value operand : llvm::reverse(getOperands())) { diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -579,16 +579,7 @@ def OpWithResultShapePerDimInterfaceOp : TEST_Op<"op_with_result_shape_per_dim_interface", - [DeclareOpInterfaceMethods]> { - let arguments = (ins AnyRankedTensor:$operand1, AnyRankedTensor:$operand2); - let results = (outs AnyRankedTensor:$result1, AnyRankedTensor:$result2); -} - -def OpWithResultShapeAndPerDimInterfaceOp : - TEST_Op<"op_with_result_shape_and_per_dim_interface", - [DeclareOpInterfaceMethods]> { + [DeclareOpInterfaceMethods]> { let arguments = (ins AnyRankedTensor:$operand1, 
AnyRankedTensor:$operand2); let results = (outs AnyRankedTensor:$result1, AnyRankedTensor:$result2); } diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -2046,6 +2046,7 @@ ":Affine", ":DialectUtils", ":IR", + ":InferTypeOpInterface", ":LinalgInterfacesIncGen", ":LinalgStructuredOpsIncGen", ":MemRefDialect",