diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
@@ -178,6 +178,10 @@
 LogicalResult
 RankOpConverter::matchAndRewrite(shape::RankOp op, ArrayRef<Value> operands,
                                  ConversionPatternRewriter &rewriter) const {
+  // For now, this lowering supports only error-free types.
+  if (op.getType().isa<SizeType>())
+    return failure();
+
   shape::RankOp::Adaptor transformed(operands);
   rewriter.replaceOpWithNewOp<DimOp>(op, transformed.shape(), 0);
   return success();
@@ -232,7 +236,9 @@
 
   // Apply conversion.
   auto module = getOperation();
-  if (failed(applyFullConversion(module, target, patterns)))
+  // Partial conversion leaves ops that have no lowering (e.g. `shape.rank`
+  // producing `!shape.size`) in place instead of failing the whole pass.
+  if (failed(applyPartialConversion(module, target, patterns)))
     signalPassFailure();
 }
 
diff --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
--- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
+++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
@@ -84,6 +84,16 @@
 
 // -----
 
+// Don't lower `rank` if type is not error-free.
+// CHECK-LABEL: @rank
+func @rank(%shape : !shape.shape) {
+  // CHECK: shape.rank
+  %rank = shape.rank %shape : !shape.shape -> !shape.size
+  return
+}
+
+// -----
+
 // Express `get_extent` as `std.dim` when it relies directly on the outcome of a
 // `shape_of` operation.
 // CHECK-LABEL: @get_extent_shape_of