diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -440,7 +440,7 @@
     arguments.
   }];
 
-  let arguments = (outs Shape_SizeOrIndexType:$arg);
+  let arguments = (ins Shape_SizeOrIndexType:$arg);
   let results = (outs Index:$result);
 
   let assemblyFormat = "$arg attr-dict `:` type($arg)";
diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
@@ -56,6 +56,21 @@
 };
 } // namespace
 
+namespace {
+class ConstSizeOpConversion : public OpConversionPattern<ConstSizeOp> {
+public:
+  using OpConversionPattern<ConstSizeOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(ConstSizeOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    rewriter.replaceOpWithNewOp<ConstantIndexOp>(op, op.value().getSExtValue());
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class ShapeOfOpConversion : public OpConversionPattern<ShapeOfOp> {
 public:
@@ -136,6 +151,27 @@
   return success();
 }
 
+namespace {
+class ToExtentTensorOpConversion
+    : public OpConversionPattern<ToExtentTensorOp> {
+public:
+  using OpConversionPattern<ToExtentTensorOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(ToExtentTensorOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    ToExtentTensorOpAdaptor adaptor(operands);
+
+    if (!adaptor.input().getType().isa<RankedTensorType>())
+      return rewriter.notifyMatchFailure(op, "input needs to be a tensor");
+
+    rewriter.replaceOpWithNewOp<TensorCastOp>(op, adaptor.input(),
+                                              op.getType());
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class GetExtentOpConverter : public OpConversionPattern<GetExtentOp> {
   using OpConversionPattern<GetExtentOp>::OpConversionPattern;
@@ -244,9 +280,11 @@
       BinaryOpConversion<AddOp, AddIOp>,
       ConstShapeOpConverter,
       BinaryOpConversion<MulOp, MulIOp>,
+      ConstSizeOpConversion,
       GetExtentOpConverter,
       RankOpConverter,
-      ShapeOfOpConversion>(ctx);
+      ShapeOfOpConversion,
+      ToExtentTensorOpConversion>(ctx);
   // clang-format on
 }
 
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -753,7 +753,7 @@
   // `IntegerAttr`s which makes constant folding simple.
   if (Attribute arg = operands[0])
     return arg;
-  return {};
+  return impl::foldCastOp(*this);
 }
 
 void SizeToIndexOp::getCanonicalizationPatterns(
@@ -812,7 +812,7 @@
 
 OpFoldResult ToExtentTensorOp::fold(ArrayRef<Attribute> operands) {
   if (!operands[0])
-    return nullptr;
+    return impl::foldCastOp(*this);
   Builder builder(getContext());
   auto shape = llvm::to_vector<6>(
       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
diff --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
--- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
+++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
@@ -149,3 +149,27 @@
   return %result : tensor<?xindex>
 }
 
+// -----
+
+// Lower `const_size` to `std.constant`
+// CHECK-LABEL: @const_size
+func @const_size() -> index {
+  // CHECK: %[[RES:.*]] = constant 42 : index
+  %size = shape.const_size 42
+  %result = shape.size_to_index %size : !shape.size
+  // CHECK: return %[[RES]]
+  return %result : index
+}
+
+// -----
+
+// Lower `to_extent_tensor` to `std.tensor_cast`
+// CHECK-LABEL: @to_extent_tensor
+// CHECK-SAME: (%[[ARG:.*]]: tensor<?xindex>
+func @to_extent_tensor(%arg: tensor<?xindex>) -> tensor<3xindex> {
+  // CHECK-NOT: to_extent_tensor
+  // CHECK: %[[RES:.*]] = tensor_cast %[[ARG]] : tensor<?xindex> to tensor<3xindex>
+  %casted = shape.to_extent_tensor %arg : tensor<?xindex> -> tensor<3xindex>
+  // CHECK: return %[[RES]]
+  return %casted : tensor<3xindex>
+}
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -774,3 +774,22 @@
   return %result : !shape.size
 }
 
+// -----
+
+// Fold index_cast when already on index.
+// CHECK-LABEL: @fold_index_cast_on_index
+func @fold_index_cast_on_index(%arg: index) -> index {
+  // CHECK-NOT: size_to_index
+  %casted = shape.size_to_index %arg : index
+  return %casted : index
+}
+
+// -----
+
+// Fold to_extent_tensor when already on tensor.
+// CHECK-LABEL: @fold_to_extent_tensor_on_tensor
+func @fold_to_extent_tensor_on_tensor(%arg: tensor<?xindex>) -> tensor<?xindex> {
+  // CHECK-NOT: to_extent_tensor
+  %casted = shape.to_extent_tensor %arg : tensor<?xindex> -> tensor<?xindex>
+  return %casted : tensor<?xindex>
+}