diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -1556,7 +1556,9 @@
 
     The body region defines the tensor's elements. It takes index operands as
     its region arguments that span the index space. The element at the given
-    position is yielded with the `yield` operation (see `YieldOp`).
+    position is yielded with the `yield` operation (see `YieldOp`). There is
+    no defined ordering to the invocations of the body. It is conceptually
+    a "parallel map" operation.
 
     Example:
 
diff --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td
--- a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td
@@ -19,6 +19,7 @@
 def StdBufferize : FunctionPass<"std-bufferize"> {
   let summary = "Bufferize the std dialect";
   let constructor = "mlir::createStdBufferizePass()";
+  let dependentDialects = ["scf::SCFDialect"];
 }
 
 #endif // MLIR_DIALECT_STANDARD_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp b/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp
--- a/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp
@@ -12,12 +12,74 @@
 
 #include "mlir/Transforms/Bufferize.h"
 #include "PassDetail.h"
+#include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/StandardOps/Transforms/Passes.h"
+#include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/Transforms/DialectConversion.h"
 
 using namespace mlir;
 
+namespace {
+/// Bufferizes `dynamic_tensor_from_elements`: allocates a memref with the
+/// result shape, then fills it with a single rank-dimensional `scf.parallel`
+/// that clones the op's body and stores each yielded element.
+class BufferizeDynamicTensorFromElementsOp
+    : public OpConversionPattern<DynamicTensorFromElementsOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(DynamicTensorFromElementsOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    // Allocate memory.
+    Location loc = op.getLoc();
+    DynamicTensorFromElementsOp::Adaptor transformed(operands);
+    RankedTensorType tensorType = op.getType().cast<RankedTensorType>();
+    MemRefType memrefType =
+        MemRefType::get(tensorType.getShape(), tensorType.getElementType());
+    Value result =
+        rewriter.create<AllocOp>(loc, memrefType, transformed.dynamicExtents());
+
+    // Collect loop bounds.
+    int64_t rank = tensorType.getRank();
+    Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
+    Value one = rewriter.create<ConstantIndexOp>(loc, 1);
+    SmallVector<Value, 4> lowerBounds(rank, zero);
+    SmallVector<Value, 4> steps(rank, one);
+    SmallVector<Value, 4> upperBounds;
+    int nextDynamicIndex = 0;
+    for (int i = 0; i < rank; i++) {
+      Value upperBound =
+          tensorType.isDynamicDim(i)
+              ? transformed.dynamicExtents()[nextDynamicIndex++]
+              : rewriter.create<ConstantIndexOp>(loc, memrefType.getDimSize(i));
+      upperBounds.push_back(upperBound);
+    }
+
+    // Generate tensor elements with a parallel loop.
+    rewriter.create<scf::ParallelOp>(
+        loc, lowerBounds, upperBounds, steps,
+        [&](OpBuilder &b, Location loc, ValueRange ivs) {
+          BlockAndValueMapping mapping;
+          mapping.map(op.body().getArguments(), ivs);
+          for (auto &nestedOp : op.getBody()->without_terminator())
+            b.clone(nestedOp, mapping);
+          auto yieldOp = cast<YieldOp>(op.getBody()->getTerminator());
+          // The body is not isolated from above, so the yielded value may be
+          // defined outside of it. Such a value was never cloned and has no
+          // entry in `mapping`; lookupOrDefault returns it unchanged.
+          b.create<StoreOp>(loc, mapping.lookupOrDefault(yieldOp.value()),
+                            result, ivs);
+          b.create<scf::YieldOp>(loc);
+        });
+
+    rewriter.replaceOp(op, {result});
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class BufferizeExtractElementOp : public OpConversionPattern<ExtractElementOp> {
 public:
@@ -73,8 +135,10 @@
 void mlir::populateStdBufferizePatterns(MLIRContext *context,
                                         BufferizeTypeConverter &typeConverter,
                                         OwningRewritePatternList &patterns) {
-  patterns.insert<BufferizeExtractElementOp, BufferizeTensorCastOp,
-                  BufferizeTensorFromElementsOp>(typeConverter, context);
+  patterns
+      .insert<BufferizeDynamicTensorFromElementsOp, BufferizeExtractElementOp,
+              BufferizeTensorCastOp, BufferizeTensorFromElementsOp>(
+          typeConverter, context);
 }
 
 namespace {
@@ -86,9 +150,11 @@
     ConversionTarget target(*context);
 
     target.addLegalDialect<StandardOpsDialect>();
+    target.addLegalDialect<scf::SCFDialect>();
 
     populateStdBufferizePatterns(context, typeConverter, patterns);
-    target.addIllegalOp<ExtractElementOp, TensorCastOp, TensorFromElementsOp>();
+    target.addIllegalOp<DynamicTensorFromElementsOp, ExtractElementOp,
+                        TensorCastOp, TensorFromElementsOp>();
 
     if (failed(applyPartialConversion(getFunction(), target, patterns)))
       signalPassFailure();
diff --git a/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt b/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt
@@ -13,6 +13,7 @@
   LINK_LIBS PUBLIC
   MLIRIR
   MLIRPass
+  MLIRSCF
   MLIRStandard
   MLIRTransforms
   )
diff --git a/mlir/lib/Dialect/StandardOps/Transforms/PassDetail.h b/mlir/lib/Dialect/StandardOps/Transforms/PassDetail.h
--- a/mlir/lib/Dialect/StandardOps/Transforms/PassDetail.h
+++ b/mlir/lib/Dialect/StandardOps/Transforms/PassDetail.h
@@ -9,6 +9,7 @@
 #ifndef DIALECT_STANDARD_TRANSFORMS_PASSDETAIL_H_
 #define DIALECT_STANDARD_TRANSFORMS_PASSDETAIL_H_
 
+#include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/Pass/Pass.h"
 
 namespace mlir {
diff --git a/mlir/test/Dialect/Standard/bufferize.mlir b/mlir/test/Dialect/Standard/bufferize.mlir
--- a/mlir/test/Dialect/Standard/bufferize.mlir
+++ b/mlir/test/Dialect/Standard/bufferize.mlir
@@ -1,5 +1,73 @@
 // RUN: mlir-opt %s -std-bufferize | FileCheck %s
 
+// CHECK-LABEL:   func @dynamic_tensor_from_elements(
+// CHECK-SAME:                                       %[[ARG:.*]]: tensor<*xf32>,
+// CHECK-SAME:                                       %[[DYNAMIC_EXTENT:.*]]: index) -> tensor<?xindex> {
+// CHECK:           %[[MEMREF:.*]] = alloc(%[[DYNAMIC_EXTENT]]) : memref<?xindex>
+// CHECK:           %[[C0:.*]] = constant 0 : index
+// CHECK:           %[[C1:.*]] = constant 1 : index
+// CHECK:           scf.parallel (%[[I:.*]]) = (%[[C0]]) to (%[[DYNAMIC_EXTENT]]) step (%[[C1]]) {
+// CHECK:             %[[ELEM:.*]] = dim %[[ARG]], %[[I]] : tensor<*xf32>
+// CHECK:             store %[[ELEM]], %[[MEMREF]][%[[I]]] : memref<?xindex>
+// CHECK:             scf.yield
+// CHECK:           }
+// CHECK:           %[[RET:.*]] = tensor_load %[[MEMREF]] : memref<?xindex>
+// CHECK:           return %[[RET]] : tensor<?xindex>
+// CHECK:         }
+func @dynamic_tensor_from_elements(%arg: tensor<*xf32>, %rank: index) -> tensor<?xindex> {
+  %result = dynamic_tensor_from_elements %rank {
+  ^bb0(%i : index):
+    %elem = dim %arg, %i : tensor<*xf32>
+    yield %elem : index
+  } : tensor<?xindex>
+  return %result : tensor<?xindex>
+}
+
+// Additional test that checks the logic for intermixed static and dynamic
+// extents.
+//
+// CHECK-LABEL:   func @dynamic_tensor_from_elements_static_and_dynamic(
+// CHECK-SAME:                                                          %[[DYNAMIC_EXTENT:.*]]: index) -> tensor<16x?xindex> {
+// CHECK:           %[[MEMREF:.*]] = alloc(%[[DYNAMIC_EXTENT]]) : memref<16x?xindex>
+// CHECK:           %[[C0:.*]] = constant 0 : index
+// CHECK:           %[[C1:.*]] = constant 1 : index
+// CHECK:           %[[C16:.*]] = constant 16 : index
+// CHECK:           scf.parallel (%[[I:.*]], %[[J:.*]]) = (%[[C0]], %[[C0]]) to (%[[C16]], %[[DYNAMIC_EXTENT]]) step (%[[C1]], %[[C1]]) {
+// CHECK:             %[[VAL_7:.*]] = addi %[[I]], %[[J]] : index
+// CHECK:             store %[[VAL_7]], %[[MEMREF]][%[[I]], %[[J]]] : memref<16x?xindex>
+// CHECK:             scf.yield
+// CHECK:           }
+// CHECK:           %[[RET:.*]] = tensor_load %[[MEMREF]] : memref<16x?xindex>
+// CHECK:           return %[[RET]] : tensor<16x?xindex>
+// CHECK:         }
+func @dynamic_tensor_from_elements_static_and_dynamic(%arg0: index) -> tensor<16x?xindex> {
+  %result = dynamic_tensor_from_elements %arg0 {
+  ^bb0(%i: index, %j: index):
+    %sum = addi %i, %j : index
+    yield %sum : index
+  } : tensor<16x?xindex>
+  return %result : tensor<16x?xindex>
+}
+
+// Check that a value defined above the op's body can be yielded directly; it
+// is stored as-is rather than looked up among the cloned values.
+//
+// CHECK-LABEL:   func @dynamic_tensor_from_elements_yield_from_above(
+// CHECK-SAME:                                                        %[[VALUE:.*]]: index,
+// CHECK-SAME:                                                        %[[DYNAMIC_EXTENT:.*]]: index) -> tensor<?xindex> {
+// CHECK:           %[[MEMREF:.*]] = alloc(%[[DYNAMIC_EXTENT]]) : memref<?xindex>
+// CHECK:           scf.parallel (%[[I:.*]]) = (%{{.*}}) to (%[[DYNAMIC_EXTENT]]) step (%{{.*}}) {
+// CHECK:             store %[[VALUE]], %[[MEMREF]][%[[I]]] : memref<?xindex>
+// CHECK:             scf.yield
+// CHECK:           }
+func @dynamic_tensor_from_elements_yield_from_above(%value: index, %extent: index) -> tensor<?xindex> {
+  %result = dynamic_tensor_from_elements %extent {
+  ^bb0(%i : index):
+    yield %value : index
+  } : tensor<?xindex>
+  return %result : tensor<?xindex>
+}
+
 // CHECK-LABEL:   func @extract_element(
 // CHECK-SAME:                          %[[TENSOR:.*]]: tensor<?xf32>,
 // CHECK-SAME:                          %[[IDX:.*]]: index) -> f32 {