diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -207,7 +207,7 @@
   let hasFolder = 1;
   let hasCanonicalizer = 1;
-  let verifier = [{ return ::verify(*this); }];
+  let verifier = [{ return ::verifySizeOrIndexOp(*this); }];
 }
 
 def Shape_ToExtentTensorOp : Shape_Op<"to_extent_tensor", [NoSideEffect]> {
@@ -252,7 +252,7 @@
   }];
 
   let hasFolder = 1;
-  let verifier = [{ return ::verify(*this); }];
+  let verifier = [{ return ::verifySizeOrIndexOp(*this); }];
 }
 
 def Shape_IndexToSizeOp : Shape_Op<"index_to_size", [NoSideEffect]> {
@@ -325,7 +325,7 @@
     $lhs `,` $rhs `:` type($lhs) `,` type($rhs) `->` type($result) attr-dict
   }];
 
-  let verifier = [{ return ::verify(*this); }];
+  let verifier = [{ return ::verifySizeOrIndexOp(*this); }];
   let hasFolder = 1;
 }
 
@@ -412,7 +412,7 @@
 
   let assemblyFormat = "$arg `:` type($arg) `->` type($result) attr-dict";
 
-  let verifier = [{ return ::verify(*this); }];
+  let verifier = [{ return ::verifyShapeOrExtentTensorOp(*this); }];
   let hasFolder = 1;
 }
 
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -28,13 +28,35 @@
   return RankedTensorType::get({ShapedType::kDynamicSize}, IndexType::get(ctx));
 }
 
-static bool isErrorPropagationPossible(ArrayRef<Type> operandTypes) {
+static bool isErrorPropagationPossible(TypeRange operandTypes) {
   for (Type ty : operandTypes)
     if (ty.isa<SizeType>() || ty.isa<ShapeType>() || ty.isa<ValueShapeType>())
       return true;
   return false;
 }
 
+static LogicalResult verifySizeOrIndexOp(Operation *op) {
+  assert(op != nullptr && op->getNumResults() == 1);
+  Type resultTy = op->getResultTypes().front();
+  if (isErrorPropagationPossible(op->getOperandTypes()) &&
+      !resultTy.isa<SizeType>())
+    return op->emitOpError()
+           << "if at least one of the operands can hold error values then the "
+              "result must 
be of type `size` to propagate them";
+  return success();
+}
+
+static LogicalResult verifyShapeOrExtentTensorOp(Operation *op) {
+  assert(op != nullptr && op->getNumResults() == 1);
+  Type resultTy = op->getResultTypes().front();
+  if (isErrorPropagationPossible(op->getOperandTypes()) &&
+      !resultTy.isa<ShapeType>())
+    return op->emitOpError()
+           << "if at least one of the operands can hold error values then the "
+              "result must be of type `shape` to propagate them";
+  return success();
+}
+
 ShapeDialect::ShapeDialect(MLIRContext *context)
     : Dialect(getDialectNamespace(), context) {
   addOperations<
@@ -542,23 +564,6 @@
 // GetExtentOp
 //===----------------------------------------------------------------------===//
 
-static LogicalResult verify(GetExtentOp op) {
-  Type shapeTy = op.shape().getType();
-  Type dimTy = op.dim().getType();
-  Type extentTy = op.extent().getType();
-  if (isErrorPropagationPossible({shapeTy, dimTy})) {
-    if (!extentTy.isa<SizeType>())
-      op.emitError()
-          << "if at least one of the operands can hold error values then the "
-             "result must be of type `size` to propagate them";
-  } else {
-    if (extentTy.isa<SizeType>())
-      op.emitError() << "if none of the operands can hold error values then "
-                        "the result must be of type `index`";
-  }
-  return success();
-}
-
 Optional<int64_t> GetExtentOp::getConstantDim() {
   if (auto constSizeOp = dim().getDefiningOp<ConstSizeOp>())
     return constSizeOp.value().getLimitedValue();
@@ -597,15 +602,6 @@
 // RankOp
 //===----------------------------------------------------------------------===//
 
-static LogicalResult verify(shape::RankOp op) {
-  if (op.shape().getType().isa<ShapeType>() &&
-      !op.rank().getType().isa<SizeType>())
-    return op.emitOpError()
-           << "if operand is of type `shape` then the result must be of type "
-              "`size` to propagate potential errors";
-  return success();
-}
-
 OpFoldResult shape::RankOp::fold(ArrayRef<Attribute> operands) {
   auto shape = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
   if (!shape)
@@ -680,21 +676,6 @@
 // MulOp 
//===----------------------------------------------------------------------===//
 
-static LogicalResult verify(MulOp op) {
-  Type resultTy = op.result().getType();
-  if (isErrorPropagationPossible({op.lhs().getType(), op.rhs().getType()})) {
-    if (!resultTy.isa<SizeType>())
-      return op.emitOpError()
-             << "if at least one of the operands can hold error values then "
-                "the result must be of type `size` to propagate them";
-  } else {
-    if (resultTy.isa<SizeType>())
-      return op.emitError() << "if none of the operands can hold error values "
-                               "then the result must be of type `index`";
-  }
-  return success();
-}
-
 OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) {
   auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>();
   if (!lhs)
@@ -719,21 +700,6 @@
   return builder.getIndexTensorAttr(type.getShape());
 }
 
-static LogicalResult verify(ShapeOfOp op) {
-  Type resultTy = op.result().getType();
-  if (isErrorPropagationPossible(op.arg().getType())) {
-    if (!resultTy.isa<ShapeType>())
-      return op.emitOpError()
-             << "if operand is of type `value_shape` then the result must be "
-                "of type `shape` to propagate potential error shapes";
-  } else {
-    if (resultTy != getExtentTensorType(op.getContext()))
-      return op.emitOpError() << "if operand is a shaped type then the result "
-                                 "must be an extent tensor";
-  }
-  return success();
-}
-
 //===----------------------------------------------------------------------===//
 // SizeToIndexOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Shape/invalid.mlir b/mlir/test/Dialect/Shape/invalid.mlir
--- a/mlir/test/Dialect/Shape/invalid.mlir
+++ b/mlir/test/Dialect/Shape/invalid.mlir
@@ -90,39 +90,21 @@
 
 func @shape_of(%value_arg : !shape.value_shape,
               %shaped_arg : tensor<*xf32>) {
-  // expected-error@+1 {{if operand is of type `value_shape` then the result must be of type `shape` to propagate potential error shapes}}
+  // expected-error@+1 {{if at least one of the operands can hold error values then the result must be of type `shape` to 
propagate them}}
   %0 = shape.shape_of %value_arg : !shape.value_shape -> tensor<?xindex>
   return
 }
 
 // -----
 
-func @shape_of(%value_arg : !shape.value_shape,
-               %shaped_arg : tensor<*xf32>) {
-  // expected-error@+1 {{if operand is a shaped type then the result must be an extent tensor}}
-  %1 = shape.shape_of %shaped_arg : tensor<*xf32> -> !shape.shape
-  return
-}
-
-// -----
-
 func @rank(%arg : !shape.shape) {
-  // expected-error@+1 {{if operand is of type `shape` then the result must be of type `size` to propagate potential errors}}
+  // expected-error@+1 {{if at least one of the operands can hold error values then the result must be of type `size` to propagate them}}
   %0 = shape.rank %arg : !shape.shape -> index
   return
 }
 
 // -----
 
-func @get_extent_error_free(%arg : tensor<?xindex>) -> !shape.size {
-  %c0 = constant 0 : index
-  // expected-error@+1 {{if none of the operands can hold error values then the result must be of type `index`}}
-  %result = shape.get_extent %arg, %c0 : tensor<?xindex>, index -> !shape.size
-  return %result : !shape.size
-}
-
-// -----
-
 func @get_extent_error_possible(%arg : tensor<?xindex>) -> index {
   %c0 = shape.const_size 0
   // expected-error@+1 {{if at least one of the operands can hold error values then the result must be of type `size` to propagate them}}
@@ -132,14 +114,6 @@
 
 // -----
 
-func @mul_error_free(%arg : index) -> !shape.size {
-  // expected-error@+1 {{if none of the operands can hold error values then the result must be of type `index`}}
-  %result = shape.mul %arg, %arg : index, index -> !shape.size
-  return %result : !shape.size
-}
-
-// -----
-
 func @mul_error_possible(%lhs : !shape.size, %rhs : index) -> index {
   // expected-error@+1 {{if at least one of the operands can hold error values then the result must be of type `size` to propagate them}}
   %result = shape.mul %lhs, %rhs : !shape.size, index -> index