diff --git a/mlir/include/mlir/Dialect/SCF/SCFOps.td b/mlir/include/mlir/Dialect/SCF/SCFOps.td --- a/mlir/include/mlir/Dialect/SCF/SCFOps.td +++ b/mlir/include/mlir/Dialect/SCF/SCFOps.td @@ -431,8 +431,10 @@ let arguments = (ins Variadic:$results); let builders = [ - OpBuilder<"OpBuilder &builder, OperationState &result", - [{ /* nothing to do */ }]> + OpBuilder<"OpBuilder &builder, OperationState &state", + [{ /* nothing to do */ }]>, + OpBuilder<"OpBuilder &builder, OperationState &state, Value result", + [{ build(builder, state, ValueRange({result})); }]> ]; } #endif // MLIR_DIALECT_SCF_SCFOPS diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -109,6 +109,7 @@ FloatAttr getF32FloatAttr(float value); FloatAttr getF64FloatAttr(double value); + IntegerAttr getI1IntegerAttr(int8_t value); IntegerAttr getI8IntegerAttr(int8_t value); IntegerAttr getI16IntegerAttr(int16_t value); IntegerAttr getI32IntegerAttr(int32_t value); diff --git a/mlir/lib/Conversion/ShapeToSCF/CMakeLists.txt b/mlir/lib/Conversion/ShapeToSCF/CMakeLists.txt --- a/mlir/lib/Conversion/ShapeToSCF/CMakeLists.txt +++ b/mlir/lib/Conversion/ShapeToSCF/CMakeLists.txt @@ -15,5 +15,6 @@ MLIRShape MLIRPass MLIRSCF + MLIRShapeToStandard MLIRTransforms ) diff --git a/mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp b/mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp --- a/mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp +++ b/mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp @@ -9,6 +9,7 @@ #include "mlir/Conversion/ShapeToSCF/ShapeToSCF.h" #include "../PassDetail.h" +#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h" #include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/Shape/IR/Shape.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" @@ -17,48 +18,49 @@ using namespace mlir; using namespace mlir::shape; +using namespace mlir::scf; +/// Conversion patterns. namespace { /// Converts `shape.reduce` to `scf.for`. -struct ReduceOpConverter : public OpRewritePattern { +struct ReduceOpConverter : public OpConversionPattern { public: - using OpRewritePattern::OpRewritePattern; + using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(ReduceOp op, - PatternRewriter &rewriter) const final; + LogicalResult + matchAndRewrite(shape::ReduceOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override; }; } // namespace LogicalResult -ReduceOpConverter::matchAndRewrite(ReduceOp reduceOp, - PatternRewriter &rewriter) const { +ReduceOpConverter::matchAndRewrite(shape::ReduceOp reduceOp, + ArrayRef operands, + ConversionPatternRewriter &rewriter) const { + shape::ReduceOp::Adaptor transformed(operands); auto loc = reduceOp.getLoc(); + auto indexTy = rewriter.getIndexType(); Value zero = rewriter.create(loc, 0); + Value size = rewriter.create(loc, indexTy, transformed.shape(), zero); Value one = rewriter.create(loc, 1); - Value extentTensor = rewriter.create( - loc, - RankedTensorType::get({ShapedType::kDynamicSize}, - rewriter.getIndexType()), - reduceOp.shape()); - Value size = - rewriter.create(loc, rewriter.getIndexType(), extentTensor, zero); - auto loop = rewriter.create( - loc, zero, size, one, reduceOp.initVals(), + loc, zero, size, one, transformed.initVals(), [&](OpBuilder &b, Location nestedLoc, Value iv, ValueRange args) { - Value indexExtent = b.create(loc, extentTensor, iv); - Value sizeExtent = b.create(loc, indexExtent); - - SmallVector mapped_values{iv, sizeExtent}; - mapped_values.append(args.begin(), args.end()); + // TODO: Consider to determine the extent with a `std.dim` operation if + // the shape originates in a tensor. + Value extent = b.create(loc, transformed.shape(), iv); + // Copy `shape.reduce` body without terminator and substitute values. + SmallVector mappedValues{iv, extent}; + mappedValues.append(args.begin(), args.end()); BlockAndValueMapping mapping; Block *reduceBody = reduceOp.getBody(); - mapping.map(reduceBody->getArguments(), mapped_values); + mapping.map(reduceBody->getArguments(), mappedValues); for (auto &nested : reduceBody->without_terminator()) b.clone(nested, mapping); + // Copy terminator and substitute values. SmallVector mappedResults; for (auto result : reduceBody->getTerminator()->getOperands()) mappedResults.push_back(mapping.lookup(result)); @@ -69,6 +71,24 @@ return success(); } +/// Type conversions. +class ShapeTypeConverter : public TypeConverter { +public: + using TypeConverter::convertType; + + ShapeTypeConverter(MLIRContext *ctx) { + // Add default pass-through conversion. + addConversion([&](Type type) { return type; }); + + addConversion([ctx](SizeType type) { return IndexType::get(ctx); }); + addConversion([ctx](ShapeType type) { + return RankedTensorType::get({ShapedType::kDynamicSize}, + IndexType::get(ctx)); + }); + } +}; + +/// Conversion pass. namespace { struct ConvertShapeToSCFPass : public ConvertShapeToSCFBase { @@ -77,15 +97,22 @@ } // namespace void ConvertShapeToSCFPass::runOnFunction() { + // Setup type conversion. MLIRContext &ctx = getContext(); + ShapeTypeConverter typeConverter(&ctx); OwningRewritePatternList patterns; populateShapeToSCFConversionPatterns(patterns, &ctx); + populateShapeToStandardConversionPatterns(patterns, &ctx); + populateFuncOpTypeConversionPattern(patterns, &ctx, typeConverter); ConversionTarget target(getContext()); - target.addLegalDialect(); - target.addIllegalOp(); - if (failed(mlir::applyPartialConversion(getFunction(), target, patterns))) + target.addLegalDialect(); + target.addDynamicallyLegalOp([&](FuncOp op) { + return typeConverter.isSignatureLegal(op.getType()) && + typeConverter.isLegal(&op.getBody()); + }); + if (failed(applyPartialConversion(getFunction(), target, patterns))) signalPassFailure(); } diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -165,6 +165,10 @@ return IntegerAttr::get(getIntegerType(8), APInt(8, value)); } +IntegerAttr Builder::getI1IntegerAttr(int8_t value) { + return IntegerAttr::get(getIntegerType(1), APInt(1, value)); +} + IntegerAttr Builder::getIntegerAttr(Type type, int64_t value) { if (type.isIndex()) return IntegerAttr::get(type, APInt(64, value)); diff --git a/mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir b/mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir --- a/mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir +++ b/mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir @@ -1,28 +1,23 @@ // RUN: mlir-opt -convert-shape-to-scf -split-input-file %s | FileCheck %s -// CHECK-LABEL: shape_reduce -// CHECK-SAME: [[SHAPE:%.*]]: !shape.shape) -> !shape.size { -func @shape_reduce(%shape : !shape.shape) -> !shape.size { +// CHECK-LABEL: @shape_reduce +// CHECK-SAME: (%[[SHAPE:.*]]: tensor) +func @shape_reduce(%shape : !shape.shape) { %init = shape.const_size 1 %num_elements = shape.reduce(%shape, %init) -> !shape.size { ^bb0(%index: index, %dim: !shape.size, %acc: !shape.size): %new_acc = shape.mul %acc, %dim shape.yield %new_acc : !shape.size } - return %num_elements : !shape.size + return } -// CHECK-NEXT: [[SHAPE_C1:%.*]] = shape.const_size 1 -// CHECK-NEXT: [[C0:%.*]] = constant 0 : index -// CHECK-NEXT: [[C1:%.*]] = constant 1 : index - -// CHECK-NEXT: [[EXTENTS:%.*]] = shape.to_extent_tensor [[SHAPE]] -// CHECK-NEXT: [[SIZE:%.*]] = dim [[EXTENTS]], [[C0]] : tensor - -// CHECK-NEXT: [[RESULT:%.*]] = scf.for [[I:%.*]] = [[C0]] to [[SIZE]] -// CHECK-SAME: step [[C1]] iter_args([[ACC:%.*]] = [[SHAPE_C1]]) -// CHECK-NEXT: [[EXTENT_INDEX:%.*]] = extract_element [[EXTENTS]]{{\[}}[[I]]] -// CHECK-NEXT: [[EXTENT:%.*]] = shape.index_to_size [[EXTENT_INDEX]] -// CHECK-NEXT: [[NEW_ACC:%.*]] = shape.mul [[ACC]], [[EXTENT]] -// CHECK-NEXT: scf.yield [[NEW_ACC]] : !shape.size -// CHECK-NEXT: } -// CHECK-NEXT: return [[RESULT]] : !shape.size +// CHECK: %[[INIT:.*]] = constant 1 : index +// CHECK: %[[C0:.*]] = constant 0 : index +// CHECK: %[[SIZE:.*]] = dim %[[SHAPE]], %[[C0]] : tensor +// CHECK: %[[C1:.*]] = constant 1 : index +// CHECK: %[[RESULT:.*]] = scf.for %[[I:.*]] = %[[C0]] to %[[SIZE]] +// CHECK-SAME: step %[[C1]] iter_args(%[[ACC:.*]] = %[[INIT]]) -> (index) { +// CHECK: %[[EXTENT:.*]] = extract_element %[[SHAPE]][%[[I]]] +// CHECK: %[[ACC_NEXT:.*]] = muli %[[ACC]], %[[EXTENT]] +// CHECK: scf.yield %[[ACC_NEXT]] : index +// CHECK: }