diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
@@ -49,8 +49,14 @@
   LogicalResult
   matchAndRewrite(SrcOpTy op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    typename SrcOpTy::Adaptor adaptor(operands);
-    rewriter.replaceOpWithNewOp<DstOpTy>(op, adaptor.lhs(), adaptor.rhs());
+    typename SrcOpTy::Adaptor transformed(operands);
+
+    // For now, only error-free types are supported by this lowering.
+    if (op.getType().template isa<SizeType>())
+      return failure();
+
+    rewriter.replaceOpWithNewOp<DstOpTy>(op, transformed.lhs(),
+                                         transformed.rhs());
     return success();
   }
 };
@@ -85,27 +91,31 @@
 LogicalResult ShapeOfOpConversion::matchAndRewrite(
     ShapeOfOp op, ArrayRef<Value> operands,
     ConversionPatternRewriter &rewriter) const {
-  ShapeOfOp::Adaptor transformed(operands);
-  auto loc = op.getLoc();
-  auto tensorVal = transformed.arg();
-  auto tensorTy = tensorVal.getType();
+
+  // For now, only error-free types are supported by this lowering.
+  if (op.getType().isa<ShapeType>())
+    return failure();
 
   // For unranked tensors `shape_of` lowers to `scf` and the pattern can be
   // found in the corresponding pass.
+  ShapeOfOp::Adaptor transformed(operands);
+  Value tensorVal = transformed.arg();
+  Type tensorTy = tensorVal.getType();
   if (tensorTy.isa<UnrankedTensorType>())
     return failure();
 
   // Build values for individual dimensions.
   SmallVector<Value, 8> dimValues;
-  auto rankedTensorTy = tensorTy.cast<RankedTensorType>();
+  RankedTensorType rankedTensorTy = tensorTy.cast<RankedTensorType>();
   int64_t rank = rankedTensorTy.getRank();
+  auto loc = op.getLoc();
   for (int64_t i = 0; i < rank; i++) {
     if (rankedTensorTy.isDynamicDim(i)) {
-      auto dimVal = rewriter.create<DimOp>(loc, tensorVal, i);
+      Value dimVal = rewriter.create<DimOp>(loc, tensorVal, i);
       dimValues.push_back(dimVal);
     } else {
       int64_t dim = rankedTensorTy.getDimSize(i);
-      auto dimVal = rewriter.create<ConstantIndexOp>(loc, dim);
+      Value dimVal = rewriter.create<ConstantIndexOp>(loc, dim);
       dimValues.push_back(dimVal);
     }
   }
@@ -187,11 +197,18 @@
     ConversionPatternRewriter &rewriter) const {
   GetExtentOp::Adaptor transformed(operands);
 
-  // Derive shape extent directly from shape origin if possible.
-  // This circumvents the necessity to materialize the shape in memory.
+  // For now, only error-free types are supported by this lowering.
+  if (op.getType().isa<SizeType>())
+    return failure();
+
+  // Derive shape extent directly from shape origin if possible. This
+  // circumvents the necessity to materialize the shape in memory.
   if (auto shapeOfOp = op.shape().getDefiningOp<ShapeOfOp>()) {
-    rewriter.replaceOpWithNewOp<DimOp>(op, shapeOfOp.arg(), transformed.dim());
-    return success();
+    if (shapeOfOp.arg().getType().isa<RankedTensorType>()) {
+      rewriter.replaceOpWithNewOp<DimOp>(op, shapeOfOp.arg(),
+                                         transformed.dim());
+      return success();
+    }
   }
 
   rewriter.replaceOpWithNewOp<ExtractElementOp>(op, rewriter.getIndexType(),
@@ -241,7 +258,7 @@
 
   // Apply conversion.
   auto module = getOperation();
-  if (failed(applyFullConversion(module, target, patterns)))
+  if (failed(applyPartialConversion(module, target, patterns)))
     signalPassFailure();
 }
 
diff --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
--- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
+++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
@@ -13,6 +13,30 @@
 
 // -----
 
+// Don't lower binary ops when they operate on `shape.size`.
+// CHECK-LABEL: @binary_ops_on_size
+// CHECK-SAME: (%[[LHS:.*]]: !shape.size, %[[RHS:.*]]: !shape.size)
+func @binary_ops_on_size(%lhs : !shape.size, %rhs : !shape.size) {
+  // CHECK: shape.add %[[LHS]], %[[RHS]] : !shape.size, !shape.size -> !shape.size
+  // CHECK: shape.mul %[[LHS]], %[[RHS]] : !shape.size, !shape.size -> !shape.size
+  %sum = shape.add %lhs, %rhs : !shape.size, !shape.size -> !shape.size
+  %prod = shape.mul %lhs, %rhs : !shape.size, !shape.size -> !shape.size
+  return
+}
+
+// -----
+
+// Don't lower `shape_of` with `shape.shape` type.
+// CHECK-LABEL: @shape_of
+// CHECK-SAME: (%[[ARG:.*]]: tensor<1x2x3xf32>)
+func @shape_of(%arg : tensor<1x2x3xf32>) {
+  // CHECK: shape.shape_of %[[ARG]] : tensor<1x2x3xf32> -> !shape.shape
+  %shape = shape.shape_of %arg : tensor<1x2x3xf32> -> !shape.shape
+  return
+}
+
+// -----
+
 // Lower `shape_of` for statically shaped tensor.
 // CHECK-LABEL: @shape_of_stat
 // CHECK-SAME: (%[[ARG:.*]]: tensor<1x2x3xf32>)
@@ -55,6 +79,17 @@
 
 // -----
 
+// Don't lower `get_extent` if it is of type `shape.size`.
+// CHECK-LABEL: @get_extent
+func @get_extent(%shape : tensor<?xindex>, %idx : !shape.size) -> !shape.size {
+  // CHECK: shape.get_extent
+  %result = shape.get_extent %shape, %idx
+      : tensor<?xindex>, !shape.size -> !shape.size
+  return %result : !shape.size
+}
+
+// -----
+
 // Express `get_extent` as `std.dim` when it relies directly on the outcome of a
 // `shape_of` operation.
 // CHECK-LABEL: @get_extent_shape_of
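Note on the `applyFullConversion` -> `applyPartialConversion` switch: patterns that now bail out on error-carrying types (`!shape.size`, `!shape.shape`) no longer cause the whole pass to fail; those ops simply survive in the `shape` dialect alongside lowered ones. A minimal sketch of mixed input under this change (the function name `@mixed` and the assumption that index-typed `shape.add` lowers to `std.addi` via `BinaryOpConversion` are illustrative, not taken from this patch):

```mlir
// Hypothetical mixed-type input to the shape-to-standard lowering.
func @mixed(%a : index, %b : index, %c : !shape.size) -> index {
  // Error-free `index` operands: BinaryOpConversion applies
  // (assumed to produce `std.addi`).
  %0 = shape.add %a, %b : index, index -> index
  // `!shape.size` operands may carry an error value, so the new guard
  // makes the pattern fail and partial conversion keeps this op as-is.
  %1 = shape.add %c, %c : !shape.size, !shape.size -> !shape.size
  return %0 : index
}
```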