diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
@@ -15,10 +15,26 @@
 #include "mlir/Transforms/DialectConversion.h"
 
 using namespace mlir;
+using namespace mlir::shape;
 
 namespace {
 
 /// Conversion patterns.
+template <typename SrcOpTy, typename DstOpTy>
+class BinaryOpConversion : public OpConversionPattern<SrcOpTy> {
+public:
+  using OpConversionPattern<SrcOpTy>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(SrcOpTy op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    typename SrcOpTy::OperandAdaptor adaptor(operands);
+    rewriter.replaceOpWithNewOp<DstOpTy>(op.getOperation(), adaptor.lhs(),
+                                         adaptor.rhs());
+    return success();
+  }
+};
+
 class FromExtentTensorOpConversion
     : public OpConversionPattern<shape::FromExtentTensorOp> {
 public:
@@ -128,6 +144,8 @@
     OwningRewritePatternList &patterns, MLIRContext *ctx) {
   // clang-format off
   patterns.insert<
+      BinaryOpConversion<AddOp, AddIOp>,
+      BinaryOpConversion<MulOp, MulIOp>,
       FromExtentTensorOpConversion,
       IndexToSizeOpConversion,
       SizeToIndexOpConversion,
diff --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
--- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
+++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
@@ -62,3 +62,16 @@
       : (tensor<?xindex>) -> !shape.shape
   return %shape : !shape.shape
 }
+
+// -----
+
+// Lower binary ops.
+// CHECK-LABEL: @binary_ops
+// CHECK-SAME: (%[[LHS:.*]]: index, %[[RHS:.*]]: index)
+func @binary_ops(%lhs : !shape.size, %rhs : !shape.size) {
+  %sum = "shape.add"(%lhs, %rhs) : (!shape.size, !shape.size) -> !shape.size
+  // CHECK-NEXT: addi %[[LHS]], %[[RHS]] : index
+  %product = shape.mul %lhs, %rhs
+  // CHECK-NEXT: muli %[[LHS]], %[[RHS]] : index
+  return
+}