diff --git a/mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp b/mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp
--- a/mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp
+++ b/mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp
@@ -17,45 +17,46 @@
 using namespace mlir;
 using namespace mlir::shape;
+using namespace mlir::scf;
 
 namespace {
 /// Converts `shape.reduce` to `scf.for`.
-struct ReduceOpConverter : public OpRewritePattern<ReduceOp> {
+struct ReduceOpConverter : public OpConversionPattern<shape::ReduceOp> {
 public:
-  using OpRewritePattern::OpRewritePattern;
+  using OpConversionPattern::OpConversionPattern;
 
-  LogicalResult matchAndRewrite(ReduceOp op,
-                                PatternRewriter &rewriter) const final;
+  LogicalResult
+  matchAndRewrite(shape::ReduceOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const final;
 };
 } // namespace
 
 LogicalResult
-ReduceOpConverter::matchAndRewrite(ReduceOp reduceOp,
-                                   PatternRewriter &rewriter) const {
-  auto loc = reduceOp.getLoc();
+ReduceOpConverter::matchAndRewrite(shape::ReduceOp op, ArrayRef<Value> operands,
+                                   ConversionPatternRewriter &rewriter) const {
+  // For now, this lowering is only defined on `tensor<?xindex>` operands.
+  if (!op.shape().getType().isa<RankedTensorType>())
+    return failure();
+
+  auto loc = op.getLoc();
+  shape::ReduceOp::Adaptor transformed(operands);
 
   Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
   Value one = rewriter.create<ConstantIndexOp>(loc, 1);
-  Value extentTensor = rewriter.create<ToExtentTensorOp>(
-      loc,
-      RankedTensorType::get({ShapedType::kDynamicSize},
-                            rewriter.getIndexType()),
-      reduceOp.shape());
-  Value size =
-      rewriter.create<DimOp>(loc, rewriter.getIndexType(), extentTensor, zero);
+  Type indexTy = rewriter.getIndexType();
+  Value rank = rewriter.create<DimOp>(loc, indexTy, transformed.shape(), zero);
 
   auto loop = rewriter.create<scf::ForOp>(
-      loc, zero, size, one, reduceOp.initVals(),
-      [&](OpBuilder &b, Location nestedLoc, Value iv, ValueRange args) {
-        Value indexExtent = b.create<ExtractElementOp>(loc, extentTensor, iv);
-        Value sizeExtent = b.create<IndexToSizeOp>(loc, indexExtent);
+      loc, zero, rank, one, op.initVals(),
+      [&](OpBuilder &b, Location loc, Value iv, ValueRange args) {
+        Value extent = b.create<ExtractElementOp>(loc, transformed.shape(), iv);
 
-        SmallVector<Value, 2> mapped_values{iv, sizeExtent};
-        mapped_values.append(args.begin(), args.end());
+        SmallVector<Value, 2> mappedValues{iv, extent};
+        mappedValues.append(args.begin(), args.end());
 
         BlockAndValueMapping mapping;
-        Block *reduceBody = reduceOp.getBody();
-        mapping.map(reduceBody->getArguments(), mapped_values);
+        Block *reduceBody = op.getBody();
+        mapping.map(reduceBody->getArguments(), mappedValues);
         for (auto &nested : reduceBody->without_terminator())
           b.clone(nested, mapping);
@@ -65,7 +66,7 @@
         b.create<scf::YieldOp>(loc, mappedResults);
       });
 
-  rewriter.replaceOp(reduceOp, loop.getResults());
+  rewriter.replaceOp(op, loop.getResults());
   return success();
 }
@@ -138,8 +139,8 @@
 
   // Setup target legality.
   ConversionTarget target(getContext());
-  target.addLegalDialect<SCFDialect, StandardOpsDialect>();
-  target.addIllegalOp<ReduceOp>();
+  target.addLegalDialect<SCFDialect, StandardOpsDialect>();
+  target.addLegalOp<ModuleOp, FuncOp>();
 
   // Apply conversion.
   if (failed(applyPartialConversion(getFunction(), target, patterns)))
diff --git a/mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir b/mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir
--- a/mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir
+++ b/mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir
@@ -1,31 +1,26 @@
 // RUN: mlir-opt -convert-shape-to-scf -split-input-file %s | FileCheck %s
 
 // CHECK-LABEL: @shape_reduce
-// CHECK-SAME:  ([[SHAPE:%.*]]: !shape.shape) -> !shape.size
-func @shape_reduce(%shape : !shape.shape) -> !shape.size {
-  %init = shape.const_size 1
-  %num_elements = shape.reduce(%shape, %init) : !shape.shape -> !shape.size {
-    ^bb0(%index: index, %dim: !shape.size, %acc: !shape.size):
-      %new_acc = shape.mul %acc, %dim
-      shape.yield %new_acc : !shape.size
+// CHECK-SAME:  (%[[SHAPE:.*]]: tensor<?xindex>) -> index
+func @shape_reduce(%shape : tensor<?xindex>) -> index {
+  %init = constant 1 : index
+  %num_elements = shape.reduce(%shape, %init) : tensor<?xindex> -> index {
+    ^bb0(%index : index, %extent : index, %acc: index):
+      %new_acc = muli %acc, %extent : index
+      shape.yield %new_acc : index
   }
-  return %num_elements : !shape.size
+  return %num_elements : index
 }
-// CHECK-NEXT: [[SHAPE_C1:%.*]] = shape.const_size 1
-// CHECK-NEXT: [[C0:%.*]] = constant 0 : index
-// CHECK-NEXT: [[C1:%.*]] = constant 1 : index
-
-// CHECK-NEXT: [[EXTENTS:%.*]] = shape.to_extent_tensor [[SHAPE]]
-// CHECK-NEXT: [[SIZE:%.*]] = dim [[EXTENTS]], [[C0]] : tensor<?xindex>
-
-// CHECK-NEXT: [[RESULT:%.*]] = scf.for [[I:%.*]] = [[C0]] to [[SIZE]]
-// CHECK-SAME:     step [[C1]] iter_args([[ACC:%.*]] = [[SHAPE_C1]])
-// CHECK-NEXT:   [[EXTENT_INDEX:%.*]] = extract_element [[EXTENTS]]{{\[}}[[I]]]
-// CHECK-NEXT:   [[EXTENT:%.*]] = shape.index_to_size [[EXTENT_INDEX]]
-// CHECK-NEXT:   [[NEW_ACC:%.*]] = shape.mul [[ACC]], [[EXTENT]]
-// CHECK-NEXT:   scf.yield [[NEW_ACC]] : !shape.size
+// CHECK-NEXT: %[[INIT:.*]] = constant 1 : index
+// CHECK-NEXT: %[[C0:.*]] = constant 0 : index
+// CHECK-NEXT: %[[C1:.*]] = constant 1 : index
+// CHECK-NEXT: %[[RANK:.*]] = dim %[[SHAPE]], %[[C0]] : tensor<?xindex>
+// CHECK-NEXT: %[[RESULT:.*]] = scf.for %[[I:.*]] = %[[C0]] to %[[RANK]] step %[[C1]] iter_args(%[[ACC:.*]] = %[[INIT]]) -> (index)
+// CHECK-NEXT:   %[[EXTENT:.*]] = extract_element %[[SHAPE]][%[[I]]]
+// CHECK-NEXT:   %[[NEW_ACC:.*]] = muli %[[ACC]], %[[EXTENT]] : index
+// CHECK-NEXT:   scf.yield %[[NEW_ACC]] : index
 // CHECK-NEXT: }
-// CHECK-NEXT: return [[RESULT]] : !shape.size
+// CHECK-NEXT: return %[[RESULT]] : index
 
 // -----