diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp --- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp +++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp @@ -49,8 +49,17 @@ LogicalResult matchAndRewrite(SrcOpTy op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - typename SrcOpTy::Adaptor adaptor(operands); - rewriter.replaceOpWithNewOp(op, adaptor.lhs(), adaptor.rhs()); + typename SrcOpTy::Adaptor transformed(operands); + + // For now, only error-free types are supported by this lowering. + Value lhs = transformed.lhs(); + if (lhs.getType().isa()) + return failure(); + Value rhs = transformed.rhs(); + if (rhs.getType().isa()) + return failure(); + + rewriter.replaceOpWithNewOp(op, lhs, rhs); return success(); } }; @@ -75,6 +84,10 @@ auto tensorVal = transformed.arg(); auto tensorTy = tensorVal.getType(); + // For now, only error-free types are supported by this lowering. + if (tensorTy.isa()) + return failure(); + // For unranked tensors `shape_of` lowers to `scf` and the pattern can be // found in the corresponding pass. if (tensorTy.isa()) @@ -118,11 +131,18 @@ ConversionPatternRewriter &rewriter) const { GetExtentOp::Adaptor transformed(operands); + // For now, only error-free types are supported by this lowering. + if (op.shape().getType().isa()) + return failure(); + // Derive shape extent directly from shape origin if possible. // This circumvents the necessity to materialize the shape in memory. if (auto shapeOfOp = op.shape().getDefiningOp()) { - rewriter.replaceOpWithNewOp(op, shapeOfOp.arg(), transformed.dim()); - return success(); + if (shapeOfOp.arg().getType().isa()) { + rewriter.replaceOpWithNewOp(op, shapeOfOp.arg(), + transformed.dim()); + return success(); + } } rewriter.replaceOpWithNewOp(op, rewriter.getIndexType(), @@ -178,7 +198,7 @@ // Apply conversion. auto module = getOperation(); - if (failed(applyFullConversion(module, target, patterns))) + if (failed(applyPartialConversion(module, target, patterns))) signalPassFailure(); }