diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -1409,7 +1409,7 @@
     ```
   }];
 
-  let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor],
+  let arguments = (ins AnyTypeOf<[AnyRankedOrUnrankedMemRef, AnyTensor],
                                  "any tensor or memref type">:$memrefOrTensor,
                        Index:$index);
   let results = (outs Index:$result);
@@ -2024,16 +2024,18 @@
 def RankOp : Std_Op<"rank", [NoSideEffect]> {
   let summary = "rank operation";
   let description = [{
-    The `rank` operation takes a tensor operand and returns its rank.
+    The `rank` operation takes a memref/tensor operand and returns its rank.
 
     Example:
 
     ```mlir
-    %1 = rank %0 : tensor<*xf32>
+    %1 = rank %arg0 : tensor<*xf32>
+    %2 = rank %arg1 : memref<*xf32>
     ```
   }];
 
-  let arguments = (ins AnyTensor);
+  let arguments = (ins AnyTypeOf<[AnyRankedOrUnrankedMemRef, AnyTensor],
+                                 "any tensor or memref type">:$memrefOrTensor);
   let results = (outs Index);
 
   let verifier = ?;
@@ -2044,7 +2046,7 @@
     }]>];
 
   let hasFolder = 1;
-  let assemblyFormat = "operands attr-dict `:` type(operands)";
+  let assemblyFormat = "$memrefOrTensor attr-dict `:` type($memrefOrTensor)";
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -2039,10 +2039,12 @@
 //===----------------------------------------------------------------------===//
 
 OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
-  // Constant fold rank when the rank of the tensor is known.
+  // Constant fold rank when the rank of the operand is known.
   auto type = getOperand().getType();
-  if (auto tensorType = type.dyn_cast<RankedTensorType>())
-    return IntegerAttr::get(IndexType::get(getContext()), tensorType.getRank());
+  if (auto shapedType = type.dyn_cast<ShapedType>())
+    if (shapedType.hasRank())
+      return IntegerAttr::get(IndexType::get(getContext()),
+                              shapedType.getRank());
   return IntegerAttr();
 }
 
diff --git a/mlir/test/Transforms/constant-fold.mlir b/mlir/test/Transforms/constant-fold.mlir
--- a/mlir/test/Transforms/constant-fold.mlir
+++ b/mlir/test/Transforms/constant-fold.mlir
@@ -686,6 +686,18 @@
 
 // -----
 
+// CHECK-LABEL: func @fold_rank_memref
+func @fold_rank_memref(%arg0 : memref<?x?xf32>) -> (index) {
+  // Fold a rank into a constant
+  // CHECK-NEXT: [[C2:%.+]] = constant 2 : index
+  %rank_0 = rank %arg0 : memref<?x?xf32>
+
+  // CHECK-NEXT: return [[C2]]
+  return %rank_0 : index
+}
+
+// -----
+
 // CHECK-LABEL: func @nested_isolated_region
 func @nested_isolated_region() {
   // CHECK-NEXT: func @isolated_op