diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -235,9 +235,10 @@ an error then it returns an error size.
   }];
 
   let arguments = (ins Shape_ShapeOrExtentTensorType:$shape,
-                       Shape_SizeType:$dim);
-  let results = (outs Shape_SizeType:$extent);
-  let assemblyFormat = "$shape `,` $dim `:` type($shape) attr-dict";
+                       Shape_SizeOrIndexType:$dim);
+  let results = (outs Shape_SizeOrIndexType:$extent);
+  let assemblyFormat = "$shape `,` $dim `:` type($shape) `,` type($dim) `->` "
+                       "type($extent) attr-dict";
 
   let builders = [
     // Builder that allows passing a constant dimension as a simple integer.
@@ -251,6 +252,7 @@
   }];
 
   let hasFolder = 1;
+  let verifier = [{ return ::verify(*this); }];
 }
 
 def Shape_IndexToSizeOp : Shape_Op<"index_to_size", [NoSideEffect]> {
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -535,10 +535,35 @@
 // GetExtentOp
 //===----------------------------------------------------------------------===//
 
+// Verifies that the result type matches the operands' ability to carry errors:
+// if `shape` is a `!shape.shape` or `dim` is a `!shape.size`, the result must
+// be a `!shape.size`; otherwise it must be an `index`. Diagnostics are returned
+// (not just emitted) so that a mismatch actually fails verification.
+static LogicalResult verify(GetExtentOp op) {
+  Type shapeTy = op.shape().getType();
+  Type dimTy = op.dim().getType();
+  Type extentTy = op.extent().getType();
+  bool errorPropagationPossible =
+      shapeTy.isa<ShapeType>() || dimTy.isa<SizeType>();
+  if (errorPropagationPossible) {
+    if (!extentTy.isa<SizeType>())
+      return op.emitError()
+             << "if at least one of the operands can hold error values then "
+                "the result must be of type `size` to propagate them";
+  } else {
+    if (extentTy.isa<SizeType>())
+      return op.emitError()
+             << "if none of the operands can hold error values then the "
+                "result must be of type `index`";
+  }
+  return success();
+}
+
 Optional<int64_t> GetExtentOp::getConstantDim() {
-  if (auto constSizeOp = dim().getDefiningOp<ConstSizeOp>()) {
+  if (auto constSizeOp = dim().getDefiningOp<ConstSizeOp>())
     return constSizeOp.value().getLimitedValue();
-  }
+  if (auto constantOp = dim().getDefiningOp<ConstantOp>())
+    return constantOp.value().cast<IntegerAttr>().getInt();
   return llvm::None;
 }
 
@@ -558,8 +583,16 @@
                         int64_t dim) {
   auto loc = result.location;
   auto dimAttr = builder.getIndexAttr(dim);
-  Value dimValue = builder.create<ConstSizeOp>(loc, dimAttr);
-  build(builder, result, shape, dimValue);
+  // Choose operand and result types that the verifier accepts: `!shape.size`
+  // for a `!shape.shape` operand and `index` for an extent tensor operand.
+  if (shape.getType().isa<ShapeType>()) {
+    Value dimValue = builder.create<ConstSizeOp>(loc, dimAttr);
+    build(builder, result, builder.getType<SizeType>(), shape, dimValue);
+  } else {
+    Value dimValue =
+        builder.create<ConstantOp>(loc, builder.getIndexType(), dimAttr);
+    build(builder, result, builder.getIndexType(), shape, dimValue);
+  }
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
--- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
+++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
@@ -136,28 +136,25 @@
 // `shape_of` operation.
 // CHECK-LABEL: @get_extent_shape_of
 // CHECK-SAME: (%[[ARG:.*]]: tensor<2x3xf32>, %[[IDX:.*]]: index) -> index
-func @get_extent_shape_of(%arg : tensor<2x3xf32>, %idx : !shape.size)
-    -> !shape.size {
+func @get_extent_shape_of(%arg : tensor<2x3xf32>, %idx : index) -> index {
   // CHECK: %[[RESULT:.*]] = dim %[[ARG]], %[[IDX]] : tensor<2x3xf32>
   // CHECK: return %[[RESULT]] : index
   %shape = shape.shape_of %arg : tensor<2x3xf32> -> tensor<?xindex>
-  %result = shape.get_extent %shape, %idx : tensor<?xindex>
-  return %result : !shape.size
+  %result = shape.get_extent %shape, %idx : tensor<?xindex>, index -> index
+  return %result : index
 }
 
 // -----
 
-// Express `get_extent` as `std.extract_element` when it relies directly on the
-// outcome of a `from_extent_tensor` operation.
+// Express `get_extent` as `std.extract_element`.
 // CHECK-LABEL: @get_extent_from_extent_tensor
 // CHECK-SAME: (%[[EXTENTS:.*]]: tensor<?xindex>, %[[IDX:.*]]: index) -> index
-func @get_extent_from_extent_tensor(%extents : tensor<?xindex>,
-    %idx : !shape.size) -> !shape.size {
+func @get_extent_from_extent_tensor(%extents : tensor<?xindex>, %idx : index)
+    -> index {
   // CHECK: %[[RESULT:.*]] = extract_element %[[EXTENTS]][%[[IDX]]] : tensor<?xindex>
   // CHECK: return %[[RESULT]] : index
-  %shape = shape.from_extent_tensor %extents : tensor<?xindex>
-  %result = shape.get_extent %shape, %idx : !shape.shape
-  return %result : !shape.size
+  %result = shape.get_extent %extents, %idx : tensor<?xindex>, index -> index
+  return %result : index
 }
 
 // -----
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -235,13 +235,49 @@
 
 // -----
 
+// Basic folding.
+// CHECK-LABEL: func @basic
+func @basic() -> index {
+  // CHECK: constant 2 : index
+  %0 = shape.const_shape [0, 1, 2] : tensor<?xindex>
+  %c2 = constant 2 : index
+  %1 = shape.get_extent %0, %c2 : tensor<?xindex>, index -> index
+  return %1 : index
+}
+
+// -----
+
+// Should not fold.
+// CHECK-LABEL: func @out_of_bounds
+func @out_of_bounds() -> index {
+  // CHECK: shape.const_shape
+  // CHECK: shape.get_extent
+  %0 = shape.const_shape [0, 1, 2] : tensor<?xindex>
+  %c3 = constant 3 : index
+  %1 = shape.get_extent %0, %c3 : tensor<?xindex>, index -> index
+  return %1 : index
+}
+
+// -----
+
+// Should not fold.
+// CHECK-LABEL: func @not_const
+func @not_const(%arg0: tensor<?xindex>) -> index {
+  // CHECK: shape.get_extent
+  %c3 = constant 3 : index
+  %0 = shape.get_extent %arg0, %c3 : tensor<?xindex>, index -> index
+  return %0 : index
+}
+
+// -----
+
 // Basic folding.
 // CHECK-LABEL: func @basic
 func @basic() -> !shape.size {
   // CHECK: shape.const_size 2
-  %0 = shape.const_shape [0, 1, 2] : tensor<?xindex>
+  %0 = shape.const_shape [0, 1, 2] : !shape.shape
   %c2 = shape.const_size 2
-  %1 = shape.get_extent %0, %c2 : tensor<?xindex>
+  %1 = shape.get_extent %0, %c2 : !shape.shape, !shape.size -> !shape.size
   return %1 : !shape.size
 }
 
@@ -252,9 +288,9 @@
 func @out_of_bounds() -> !shape.size {
   // CHECK: shape.const_shape
   // CHECK: shape.get_extent
-  %0 = shape.const_shape [0, 1, 2] : tensor<?xindex>
+  %0 = shape.const_shape [0, 1, 2] : !shape.shape
   %c3 = shape.const_size 3
-  %1 = shape.get_extent %0, %c3 : tensor<?xindex>
+  %1 = shape.get_extent %0, %c3 : !shape.shape, !shape.size -> !shape.size
   return %1 : !shape.size
 }
 
@@ -262,14 +298,13 @@
 
 // Should not fold.
 // CHECK-LABEL: func @not_const
-func @not_const(%arg0: tensor<?xindex>) -> !shape.size {
+func @not_const(%arg0 : !shape.shape) -> !shape.size {
   // CHECK: shape.get_extent
   %c3 = shape.const_size 3
-  %0 = shape.get_extent %arg0, %c3 : tensor<?xindex>
+  %0 = shape.get_extent %arg0, %c3 : !shape.shape, !shape.size -> !shape.size
   return %0 : !shape.size
 }
 
-
 // -----
 // cstr_eq with non-constant but known equal shapes can be removed.
 // CHECK-LABEL: func @f
diff --git a/mlir/test/Dialect/Shape/invalid.mlir b/mlir/test/Dialect/Shape/invalid.mlir
--- a/mlir/test/Dialect/Shape/invalid.mlir
+++ b/mlir/test/Dialect/Shape/invalid.mlir
@@ -109,3 +109,21 @@
   %1 = shape.rank %arg : tensor<?xindex> -> !shape.size
 }
 
+// -----
+
+func @get_extent_error_free(%arg : tensor<?xindex>) -> !shape.size {
+  %c0 = constant 0 : index
+  // expected-error@+1 {{if none of the operands can hold error values then the result must be of type `index`}}
+  %result = shape.get_extent %arg, %c0 : tensor<?xindex>, index -> !shape.size
+  return %result : !shape.size
+}
+
+// -----
+
+func @get_extent_error_possible(%arg : tensor<?xindex>) -> index {
+  %c0 = shape.const_size 0
+  // expected-error@+1 {{if at least one of the operands can hold error values then the result must be of type `size` to propagate them}}
+  %result = shape.get_extent %arg, %c0 : tensor<?xindex>, !shape.size -> index
+  return %result : index
+}
+
diff --git a/mlir/test/Dialect/Shape/ops.mlir b/mlir/test/Dialect/Shape/ops.mlir
--- a/mlir/test/Dialect/Shape/ops.mlir
+++ b/mlir/test/Dialect/Shape/ops.mlir
@@ -163,13 +163,20 @@
 
 func @get_extent_on_shape(%arg : !shape.shape) -> !shape.size {
   %c0 = shape.const_size 0
-  %result = shape.get_extent %arg, %c0 : !shape.shape
+  %result = shape.get_extent %arg, %c0 :
+      !shape.shape, !shape.size -> !shape.size
   return %result : !shape.size
 }
 
-func @get_extent_on_extent_tensor(%arg : tensor<?xindex>) -> !shape.size {
+func @get_extent_on_extent_tensor(%arg : tensor<?xindex>) -> index {
+  %c0 = constant 0 : index
+  %result = shape.get_extent %arg, %c0 : tensor<?xindex>, index -> index
+  return %result : index
+}
+
+func @get_extent_on_mixed_operands(%arg : tensor<?xindex>) -> !shape.size {
   %c0 = shape.const_size 0
-  %result = shape.get_extent %arg, %c0 : tensor<?xindex>
+  %result = shape.get_extent %arg, %c0 : tensor<?xindex>, !shape.size -> !shape.size
   return %result : !shape.size
 }
 