diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td --- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td +++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td @@ -550,7 +550,8 @@ - `index` is in the range [-rank(operand),rank(operand)] }]; - let arguments = (ins Shape_ShapeOrExtentTensorType:$operand, I32:$index); + let arguments = (ins Shape_ShapeOrExtentTensorType:$operand, + Shape_SizeOrIndexType:$index); let results = (outs Shape_ShapeType:$head, Shape_ShapeType:$tail); let hasFolder = 1; } diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir --- a/mlir/test/Dialect/Shape/canonicalize.mlir +++ b/mlir/test/Dialect/Shape/canonicalize.mlir @@ -14,9 +14,9 @@ func @f() -> (!shape.shape, !shape.shape) { // CHECK: shape.const_shape [2, 3] : !shape.shape // CHECK: shape.const_shape [4, 5] : !shape.shape - %c2 = constant 2 : i32 + %c2 = constant 2 : index %0 = shape.const_shape [2, 3, 4, 5] : !shape.shape - %head, %tail = "shape.split_at"(%0, %c2) : (!shape.shape, i32) -> (!shape.shape, !shape.shape) + %head, %tail = "shape.split_at"(%0, %c2) : (!shape.shape, index) -> (!shape.shape, !shape.shape) return %head, %tail : !shape.shape, !shape.shape } @@ -28,9 +28,9 @@ func @f() -> (!shape.shape, !shape.shape) { // CHECK: shape.const_shape [2, 3, 4] : !shape.shape // CHECK: shape.const_shape [5] : !shape.shape - %c-1 = constant -1 : i32 + %c-1 = constant -1 : index %0 = shape.const_shape [2, 3, 4, 5] : !shape.shape - %head, %tail = "shape.split_at"(%0, %c-1) : (!shape.shape, i32) -> (!shape.shape, !shape.shape) + %head, %tail = "shape.split_at"(%0, %c-1) : (!shape.shape, index) -> (!shape.shape, !shape.shape) return %head, %tail : !shape.shape, !shape.shape } @@ -40,9 +40,9 @@ // CHECK-LABEL: func @f func @f() -> (!shape.shape, !shape.shape) { // CHECK: shape.split_at - %c5 = constant 5 : i32 + %c5 = constant 5 : index %0 = shape.const_shape [2, 3, 4, 5] : !shape.shape - %head, %tail = "shape.split_at"(%0, %c5) : (!shape.shape, i32) -> (!shape.shape, !shape.shape) + %head, %tail = "shape.split_at"(%0, %c5) : (!shape.shape, index) -> (!shape.shape, !shape.shape) return %head, %tail : !shape.shape, !shape.shape }