diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td --- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td +++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td @@ -572,7 +572,7 @@ ``` }]; - let arguments = (ins AnyRankedOrUnrankedMemRef:$source, + let arguments = (ins AnyNon0RankedOrUnrankedMemRef:$source, Index:$index); let results = (outs Index:$result); diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td --- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td +++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td @@ -115,7 +115,7 @@ ``` }]; - let arguments = (ins AnyTensor:$source, + let arguments = (ins AnyNon0RankedOrUnrankedTensor:$source, Index:$index); let results = (outs Index:$result); diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -548,6 +548,12 @@ == }] # rank>)>]>; +// Whether a shaped type is ranked and has a rank greater than or equal to `rank`. +class HasRankGreaterOrEqualPred : And<[ + HasRankPred, + CPred<[{$_self.cast<::mlir::ShapedType>().getRank() >= }] # rank> +]>; + // Vector types. class VectorOf allowedTypes> : @@ -748,7 +754,16 @@ string summary = "ranked tensor"> : TensorOf; +class Non0RankedTensorOf allowedTypes> + : TensorOf], + "non-0-ranked.tensor">; + def AnyRankedTensor : RankedTensorOf<[AnyType]>; +def AnyNon0RankedTensor : Non0RankedTensorOf<[AnyType]>; +def AnyUnrankedTensor : UnrankedTensorOf<[AnyType]>; + +def AnyNon0RankedOrUnrankedTensor: + AnyTypeOf<[AnyUnrankedTensor, AnyNon0RankedTensor]>; // Ranked tensor type with one of the specified types and ranks. class TensorRankOf allowedTypes, list ranks> @@ -782,13 +797,20 @@ class MemRefOf allowedTypes> : ShapedContainerType; +class Non0RankedMemRefOf allowedTypes> : + ConfinedType, [HasRankGreaterOrEqualPred<1>], + "non-0-ranked." 
# MemRefOf.summary, + "::mlir::MemRefType">; def AnyMemRef : MemRefOf<[AnyType]>; +def AnyNon0RankedMemRef : Non0RankedMemRefOf<[AnyType]>; class RankedOrUnrankedMemRefOf allowedTypes>: AnyTypeOf<[UnrankedMemRefOf, MemRefOf]>; def AnyRankedOrUnrankedMemRef: AnyTypeOf<[AnyUnrankedMemRef, AnyMemRef]>; +def AnyNon0RankedOrUnrankedMemRef: + AnyTypeOf<[AnyUnrankedMemRef, AnyNon0RankedMemRef]>; // Memref declarations handle any memref, independent of rank, size, (static or // dynamic), layout, or memory space. diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir --- a/mlir/test/Dialect/MemRef/invalid.mlir +++ b/mlir/test/Dialect/MemRef/invalid.mlir @@ -1040,3 +1040,11 @@ %0 = memref.realloc %src : memref<256xf32> to memref return %0 : memref } + +// ----- + +// A 0-D (rank-0) memref has no dimensions, so memref.dim must reject it. +func.func @dim_0_ranked(%arg : memref, %arg1 : index) { + memref.dim %arg, %arg1 : memref // expected-error {{'memref.dim' op operand #0 must be unranked.memref of any type values or non-0-ranked.memref of any type values, but got 'memref'}} + return +} diff --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir --- a/mlir/test/Dialect/Tensor/bufferize.mlir +++ b/mlir/test/Dialect/Tensor/bufferize.mlir @@ -1,13 +1,13 @@ // RUN: mlir-opt %s -tensor-bufferize -cse -split-input-file -verify-diagnostics | FileCheck %s // CHECK-LABEL: func @dim( -// CHECK-SAME: %[[TENSOR:.*]]: tensor, +// CHECK-SAME: %[[TENSOR:.*]]: tensor<*xf32>, // CHECK-SAME: %[[INDEX:.*]]: index) -> index { -// CHECK: %[[MEMREF:.*]] = bufferization.to_memref %[[TENSOR]] : memref -// CHECK: %[[EXTENT:.*]] = memref.dim %[[MEMREF]], %[[INDEX]] : memref +// CHECK: %[[MEMREF:.*]] = bufferization.to_memref %[[TENSOR]] : memref<*xf32> +// CHECK: %[[EXTENT:.*]] = memref.dim %[[MEMREF]], %[[INDEX]] : memref<*xf32> // CHECK: return %[[EXTENT]] : index -func.func @dim(%arg0: tensor, %arg1: index) -> index { - %0 = tensor.dim %arg0, 
%arg1 : tensor +func.func @dim(%arg0: tensor<*xf32>, %arg1: index) -> index { + %0 = tensor.dim %arg0, %arg1 : tensor<*xf32> return %0 : index } diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir --- a/mlir/test/Dialect/Tensor/invalid.mlir +++ b/mlir/test/Dialect/Tensor/invalid.mlir @@ -8,6 +8,14 @@ // ----- +// A 0-D (rank-0) tensor has no dimensions, so tensor.dim must reject it. +func.func @dim_0_ranked(%arg : tensor, %arg1 : index) { + tensor.dim %arg, %arg1 : tensor // expected-error {{'tensor.dim' op operand #0 must be unranked.tensor of any type values or non-0-ranked.tensor of any type values, but got 'tensor'}} + return +} + +// ----- + func.func @tensor.cast_mismatching_constants(%arg0: tensor<1xf32>) { // expected-error@+1 {{operand type 'tensor<1xf32>' and result type 'tensor<2xf32>' are cast incompatible}} %0 = tensor.cast %arg0 : tensor<1xf32> to tensor<2xf32>