diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
@@ -14,10 +14,30 @@
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Transforms/DialectConversion.h"
 
-namespace mlir {
+using namespace mlir;
+using namespace mlir::shape;
+
 namespace {
 
 /// Conversion patterns.
+
+/// Replaces a binary op `SrcOpTy` with a `DstOpTy` built from the `lhs` and
+/// `rhs` entries of the (already type-converted) `operands` handed to the rewrite.
+template <typename SrcOpTy, typename DstOpTy>
+class BinaryOpConversion : public OpConversionPattern<SrcOpTy> {
+public:
+  using OpConversionPattern<SrcOpTy>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(SrcOpTy op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    typename SrcOpTy::OperandAdaptor adaptor(operands);
+    rewriter.replaceOpWithNewOp<DstOpTy>(op.getOperation(), adaptor.lhs(),
+                                         adaptor.rhs());
+    return success();
+  }
+};
+
 class SizeToIndexOpConversion
     : public OpConversionPattern<SizeToIndexOp> {
 public:
@@ -90,17 +110,18 @@
 
 } // namespace
 
-void populateShapeToStandardConversionPatterns(
+void mlir::populateShapeToStandardConversionPatterns(
     OwningRewritePatternList &patterns, MLIRContext *ctx) {
   // clang-format off
   patterns.insert<
+      BinaryOpConversion<AddOp, AddIOp>,
+      BinaryOpConversion<MulOp, MulIOp>,
       IndexToSizeOpConversion,
       SizeToIndexOpConversion>(ctx);
   // clang-format on
 }
 
-std::unique_ptr<OperationPass<ModuleOp>> createConvertShapeToStandardPass() {
+std::unique_ptr<OperationPass<ModuleOp>>
+mlir::createConvertShapeToStandardPass() {
   return std::make_unique<ConvertShapeToStandardPass>();
 }
-
-} // namespace mlir
diff --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
--- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
+++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
@@ -29,3 +29,16 @@
   %size = shape.index_to_size %index
   return %size : !shape.size
 }
+
+// -----
+
+// Lower binary ops.
+// CHECK-LABEL: @binary_ops
+// CHECK-SAME: (%[[LHS:.*]]: index, %[[RHS:.*]]: index)
+func @binary_ops(%lhs : !shape.size, %rhs : !shape.size) {
+  %sum = "shape.add"(%lhs, %rhs) : (!shape.size, !shape.size) -> !shape.size
+  // CHECK-NEXT: addi %[[LHS]], %[[RHS]] : index
+  %product = shape.mul %lhs, %rhs
+  // CHECK-NEXT: muli %[[LHS]], %[[RHS]] : index
+  return
+}