diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
@@ -18,27 +18,34 @@
 using namespace mlir::shape;
 
 namespace {
-
 /// Generated conversion patterns.
 #include "ShapeToStandardPatterns.inc"
+} // namespace
 
 /// Conversion patterns.
+namespace {
 class AnyOpConversion : public OpConversionPattern<AnyOp> {
 public:
   using OpConversionPattern<AnyOp>::OpConversionPattern;
 
   LogicalResult
   matchAndRewrite(AnyOp op, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const override {
-    AnyOp::Adaptor transformed(operands);
-
-    // Replace `any` with its first operand.
-    // Any operand would be a valid substitution.
-    rewriter.replaceOp(op, {transformed.inputs().front()});
-    return success();
-  }
+                  ConversionPatternRewriter &rewriter) const override;
 };
+} // namespace
+
+LogicalResult
+AnyOpConversion::matchAndRewrite(AnyOp op, ArrayRef<Value> operands,
+                                 ConversionPatternRewriter &rewriter) const {
+  AnyOp::Adaptor transformed(operands);
+
+  // Replace `any` with its first operand.
+  // Any operand would be a valid substitution.
+  rewriter.replaceOp(op, {transformed.inputs().front()});
+  return success();
+}
 
+namespace {
 template <typename SrcOpTy, typename DstOpTy>
 class BinaryOpConversion : public OpConversionPattern<SrcOpTy> {
 public:
@@ -53,98 +60,122 @@
     return success();
   }
 };
+} // namespace
 
+namespace {
 class ShapeOfOpConversion : public OpConversionPattern<ShapeOfOp> {
 public:
   using OpConversionPattern<ShapeOfOp>::OpConversionPattern;
 
   LogicalResult
   matchAndRewrite(ShapeOfOp op, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const override {
-    ShapeOfOp::Adaptor transformed(operands);
-    auto loc = op.getLoc();
-    auto tensorVal = transformed.arg();
-    auto tensorTy = tensorVal.getType();
-
-    // For unranked tensors `shape_of` lowers to `scf` and the pattern can be
-    // found in the corresponding pass.
-    if (tensorTy.isa<UnrankedTensorType>())
-      return failure();
-
-    // Build values for individual dimensions.
-    SmallVector<Value, 8> dimValues;
-    auto rankedTensorTy = tensorTy.cast<RankedTensorType>();
-    int64_t rank = rankedTensorTy.getRank();
-    for (int64_t i = 0; i < rank; i++) {
-      if (rankedTensorTy.isDynamicDim(i)) {
-        auto dimVal = rewriter.create<DimOp>(loc, tensorVal, i);
-        dimValues.push_back(dimVal);
-      } else {
-        int64_t dim = rankedTensorTy.getDimSize(i);
-        auto dimVal = rewriter.create<ConstantIndexOp>(loc, dim);
-        dimValues.push_back(dimVal);
-      }
-    }
+                  ConversionPatternRewriter &rewriter) const override;
+};
+} // namespace
 
-    // Materialize extent tensor.
-    Value staticExtentTensor =
-        rewriter.create<TensorFromElementsOp>(loc, dimValues);
-    rewriter.replaceOpWithNewOp<TensorCastOp>(op, staticExtentTensor,
-                                              op.getType());
-    return success();
+LogicalResult ShapeOfOpConversion::matchAndRewrite(
+    ShapeOfOp op, ArrayRef<Value> operands,
+    ConversionPatternRewriter &rewriter) const {
+  ShapeOfOp::Adaptor transformed(operands);
+  auto loc = op.getLoc();
+  auto tensorVal = transformed.arg();
+  auto tensorTy = tensorVal.getType();
+
+  // For unranked tensors `shape_of` lowers to `scf` and the pattern can be
+  // found in the corresponding pass.
+  if (tensorTy.isa<UnrankedTensorType>())
+    return failure();
+
+  // Build values for individual dimensions.
+  SmallVector<Value, 8> dimValues;
+  auto rankedTensorTy = tensorTy.cast<RankedTensorType>();
+  int64_t rank = rankedTensorTy.getRank();
+  for (int64_t i = 0; i < rank; i++) {
+    if (rankedTensorTy.isDynamicDim(i)) {
+      auto dimVal = rewriter.create<DimOp>(loc, tensorVal, i);
+      dimValues.push_back(dimVal);
+    } else {
+      int64_t dim = rankedTensorTy.getDimSize(i);
+      auto dimVal = rewriter.create<ConstantIndexOp>(loc, dim);
+      dimValues.push_back(dimVal);
+    }
   }
-};
 
+  // Materialize extent tensor.
+  Value staticExtentTensor =
+      rewriter.create<TensorFromElementsOp>(loc, dimValues);
+  rewriter.replaceOpWithNewOp<TensorCastOp>(op, staticExtentTensor,
+                                            op.getType());
+  return success();
+}
+
+namespace {
 class ConstSizeOpConverter : public OpConversionPattern<ConstSizeOp> {
 public:
   using OpConversionPattern<ConstSizeOp>::OpConversionPattern;
 
   LogicalResult
   matchAndRewrite(ConstSizeOp op, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const override {
-    rewriter.replaceOpWithNewOp<ConstantIndexOp>(op.getOperation(),
-                                                 op.value().getSExtValue());
-    return success();
-  }
+                  ConversionPatternRewriter &rewriter) const override;
 };
+} // namespace
 
+LogicalResult ConstSizeOpConverter::matchAndRewrite(
+    ConstSizeOp op, ArrayRef<Value> operands,
+    ConversionPatternRewriter &rewriter) const {
+  rewriter.replaceOpWithNewOp<ConstantIndexOp>(op.getOperation(),
+                                               op.value().getSExtValue());
+  return success();
+}
+
+namespace {
 class GetExtentOpConverter : public OpConversionPattern<GetExtentOp> {
   using OpConversionPattern<GetExtentOp>::OpConversionPattern;
 
   LogicalResult
   matchAndRewrite(GetExtentOp op, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const override {
-    GetExtentOp::Adaptor transformed(operands);
-
-    // Derive shape extent directly from shape origin if possible.
-    // This circumvents the necessity to materialize the shape in memory.
-    if (auto shapeOfOp = op.shape().getDefiningOp<ShapeOfOp>()) {
-      rewriter.replaceOpWithNewOp<DimOp>(op, shapeOfOp.arg(),
-                                         transformed.dim());
-      return success();
-    }
+                  ConversionPatternRewriter &rewriter) const override;
+};
+} // namespace
 
-    rewriter.replaceOpWithNewOp<ExtractElementOp>(
-        op, rewriter.getIndexType(), transformed.shape(),
-        ValueRange{transformed.dim()});
+LogicalResult GetExtentOpConverter::matchAndRewrite(
+    GetExtentOp op, ArrayRef<Value> operands,
+    ConversionPatternRewriter &rewriter) const {
+  GetExtentOp::Adaptor transformed(operands);
+
+  // Derive shape extent directly from shape origin if possible.
+  // This circumvents the necessity to materialize the shape in memory.
+  if (auto shapeOfOp = op.shape().getDefiningOp<ShapeOfOp>()) {
+    rewriter.replaceOpWithNewOp<DimOp>(op, shapeOfOp.arg(), transformed.dim());
    return success();
  }
-};
 
+  rewriter.replaceOpWithNewOp<ExtractElementOp>(op, rewriter.getIndexType(),
+                                                transformed.shape(),
+                                                ValueRange{transformed.dim()});
+  return success();
+}
+
+namespace {
 class RankOpConverter : public OpConversionPattern<shape::RankOp> {
 public:
   using OpConversionPattern<shape::RankOp>::OpConversionPattern;
 
   LogicalResult
   matchAndRewrite(shape::RankOp op, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const override {
-    shape::RankOp::Adaptor transformed(operands);
-    rewriter.replaceOpWithNewOp<DimOp>(op.getOperation(), transformed.shape(),
-                                       0);
-    return success();
-  }
+                  ConversionPatternRewriter &rewriter) const override;
 };
+} // namespace
+
+LogicalResult
+RankOpConverter::matchAndRewrite(shape::RankOp op, ArrayRef<Value> operands,
+                                 ConversionPatternRewriter &rewriter) const {
+  shape::RankOp::Adaptor transformed(operands);
+  rewriter.replaceOpWithNewOp<DimOp>(op.getOperation(), transformed.shape(), 0);
+  return success();
+}
 
+namespace {
 /// Type conversions.
 class ShapeTypeConverter : public TypeConverter {
 public:
@@ -161,39 +192,42 @@
     });
   }
 };
+} // namespace
 
+namespace {
 /// Conversion pass.
 class ConvertShapeToStandardPass
     : public ConvertShapeToStandardBase<ConvertShapeToStandardPass> {
 
-  void runOnOperation() override {
-    // Setup type conversion.
-    MLIRContext &ctx = getContext();
-    ShapeTypeConverter typeConverter(&ctx);
-
-    // Setup target legality.
-    ConversionTarget target(ctx);
-    target.addLegalDialect<StandardOpsDialect>();
-    target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
-    target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
-      return typeConverter.isSignatureLegal(op.getType()) &&
-             typeConverter.isLegal(&op.getBody());
-    });
-
-    // Setup conversion patterns.
-    OwningRewritePatternList patterns;
-    populateShapeToStandardConversionPatterns(patterns, &ctx);
-    populateFuncOpTypeConversionPattern(patterns, &ctx, typeConverter);
-
-    // Apply conversion.
-    auto module = getOperation();
-    if (failed(applyFullConversion(module, target, patterns)))
-      signalPassFailure();
-  }
+  void runOnOperation() override;
 };
-
 } // namespace
 
+void ConvertShapeToStandardPass::runOnOperation() {
+  // Setup type conversion.
+  MLIRContext &ctx = getContext();
+  ShapeTypeConverter typeConverter(&ctx);
+
+  // Setup target legality.
+  ConversionTarget target(ctx);
+  target.addLegalDialect<StandardOpsDialect>();
+  target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
+  target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
+    return typeConverter.isSignatureLegal(op.getType()) &&
+           typeConverter.isLegal(&op.getBody());
+  });
+
+  // Setup conversion patterns.
+  OwningRewritePatternList patterns;
+  populateShapeToStandardConversionPatterns(patterns, &ctx);
+  populateFuncOpTypeConversionPattern(patterns, &ctx, typeConverter);
+
+  // Apply conversion.
+  auto module = getOperation();
+  if (failed(applyFullConversion(module, target, patterns)))
+    signalPassFailure();
+}
+
 void mlir::populateShapeToStandardConversionPatterns(
     OwningRewritePatternList &patterns, MLIRContext *ctx) {
   populateWithGenerated(ctx, &patterns);
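
Usage note: the patch only moves the `matchAndRewrite` bodies out of line and narrows the anonymous namespaces; `populateShapeToStandardConversionPatterns` remains the public entry point. The sketch below shows how downstream code might drive these patterns outside the pass. The helper name `lowerShapeOps`, the choice of partial conversion, and the exact include set are illustrative assumptions, not part of this change.

#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Module.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

// Hypothetical driver that lowers `shape` ops in a module to `std` ops,
// mirroring ConvertShapeToStandardPass::runOnOperation above but without
// the function signature conversion.
static LogicalResult lowerShapeOps(ModuleOp module) {
  MLIRContext *ctx = module.getContext();

  // Only the standard dialect is legal after conversion.
  ConversionTarget target(*ctx);
  target.addLegalDialect<StandardOpsDialect>();

  // Patterns exported by this file (generated and hand-written ones).
  OwningRewritePatternList patterns;
  populateShapeToStandardConversionPatterns(patterns, ctx);

  // Partial conversion leaves ops without a matching pattern untouched,
  // e.g. `shape_of` on unranked tensors, which a separate pass lowers
  // via `scf`.
  return applyPartialConversion(module, target, patterns);
}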