diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -2262,6 +2262,20 @@
     I64ArrayAttr:$static_strides
   );
   let results = (outs AnyMemRef:$result);
+
+  let builders = [
+    // Build a ReinterpretCastOp with mixed static and dynamic entries.
+    OpBuilder<
+      "MemRefType resultType, Value source, int64_t staticOffset, "
+      "ArrayRef<int64_t> staticSizes, ArrayRef<int64_t> staticStrides, "
+      "ValueRange offset, ValueRange sizes, ValueRange strides, "
+      "ArrayRef<NamedAttribute> attrs = {}">,
+    // Build a ReinterpretCastOp with all dynamic entries.
+    OpBuilder<
+      "MemRefType resultType, Value source, Value offset, ValueRange sizes, "
+      "ValueRange strides, ArrayRef<NamedAttribute> attrs = {}">,
+  ];
+
   let extraClassDeclaration = extraBaseClassDeclaration # [{
     // The result of the op is always a ranked memref.
     MemRefType getType() { return getResult().getType().cast<MemRefType>(); }
@@ -2312,7 +2326,7 @@
   let arguments = (ins
     AnyRankedOrUnrankedMemRef:$source,
-    MemRefRankOf<[AnySignlessInteger], [1]>:$shape
+    MemRefRankOf<[AnySignlessInteger, Index], [1]>:$shape
   );
   let results = (outs AnyRankedOrUnrankedMemRef:$result);
diff --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h
--- a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h
@@ -25,6 +25,9 @@
 /// Creates an instance of the ExpandAtomic pass.
 std::unique_ptr<Pass> createExpandAtomicPass();
 
+void populateExpandMemRefReshapePattern(OwningRewritePatternList &patterns,
+                                        MLIRContext *ctx);
+
 void populateExpandTanhPattern(OwningRewritePatternList &patterns,
                                MLIRContext *ctx);
 
diff --git a/mlir/lib/Conversion/StandardToLLVM/CMakeLists.txt b/mlir/lib/Conversion/StandardToLLVM/CMakeLists.txt
--- a/mlir/lib/Conversion/StandardToLLVM/CMakeLists.txt
+++ b/mlir/lib/Conversion/StandardToLLVM/CMakeLists.txt
@@ -13,5 +13,6 @@
 
   LINK_LIBS PUBLIC
   MLIRLLVMIR
+  MLIRStandardOpsTransforms
   MLIRTransforms
   )
diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/IR/Builders.h"
@@ -3666,6 +3667,7 @@
   populateStdToLLVMFuncOpConversionPattern(converter, patterns);
   populateStdToLLVMNonMemoryConversionPatterns(converter, patterns);
   populateStdToLLVMMemoryConversionPatterns(converter, patterns);
+  populateExpandMemRefReshapePattern(patterns, &converter.getContext());
 }
 
 /// Convert a non-empty list of types to be returned from a function into a
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -2269,6 +2269,34 @@
 // MemRefReinterpretCastOp
 //===----------------------------------------------------------------------===//
 
+void mlir::MemRefReinterpretCastOp::build(
+    OpBuilder &b, OperationState &result, MemRefType resultType, Value source,
+    int64_t staticOffset, ArrayRef<int64_t> staticSizes,
+    ArrayRef<int64_t> staticStrides, ValueRange offset, ValueRange sizes,
+    ValueRange strides, ArrayRef<NamedAttribute> attrs) {
+  build(b, result, resultType, source, offset, sizes, strides,
+        b.getI64ArrayAttr(staticOffset), b.getI64ArrayAttr(staticSizes),
+        b.getI64ArrayAttr(staticStrides));
+  result.addAttributes(attrs);
+}
+
+/// Build a MemRefReinterpretCastOp with all dynamic entries: `staticOffsets`,
+/// `staticSizes` and `staticStrides` are automatically filled with
+/// source-memref-rank sentinel values that encode dynamic entries.
+void mlir::MemRefReinterpretCastOp::build(OpBuilder &b, OperationState &result,
+                                          MemRefType resultType, Value source,
+                                          Value offset, ValueRange sizes,
+                                          ValueRange strides,
+                                          ArrayRef<NamedAttribute> attrs) {
+  unsigned rank = resultType.getRank();
+  SmallVector<int64_t, 4> staticSizesVector(rank, ShapedType::kDynamicSize);
+  SmallVector<int64_t, 4> staticStridesVector(
+      rank, ShapedType::kDynamicStrideOrOffset);
+  build(b, result, resultType, source,
+        /*staticOffset=*/ShapedType::kDynamicStrideOrOffset, staticSizesVector,
+        staticStridesVector, offset, sizes, strides, attrs);
+}
+
 /// Print of the form:
 /// ```
 ///   `name` ssa-name to
@@ -2391,18 +2419,6 @@
                                              op.strides())))
     return failure();
 
-  // Extract source offset and strides.
-  int64_t resultOffset;
-  SmallVector<int64_t, 4> resultStrides;
-  if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
-    return failure();
-
-  // Match offset in result memref type and in static_offsets attribute.
-  int64_t expectedOffset = extractFromI64ArrayAttr(op.static_offsets()).front();
-  if (resultOffset != expectedOffset)
-    return op.emitError("expected result type with offset = ")
-           << resultOffset << " instead of " << expectedOffset;
-
   // Match sizes in result memref type and in static_sizes attribute.
   for (auto &en :
        llvm::enumerate(llvm::zip(resultType.getShape(),
@@ -2415,15 +2431,31 @@
              << " in dim = " << en.index();
   }
 
-  // Match strides in result memref type and in static_strides attribute.
-  for (auto &en : llvm::enumerate(llvm::zip(
-           resultStrides, extractFromI64ArrayAttr(op.static_strides())))) {
-    int64_t resultStride = std::get<0>(en.value());
-    int64_t expectedStride = std::get<1>(en.value());
-    if (resultStride != expectedStride)
-      return op.emitError("expected result type with stride = ")
-             << expectedStride << " instead of " << resultStride
-             << " in dim = " << en.index();
+  // Match offset and strides in static_offset and static_strides attributes if
+  // result memref type has an affine map specified.
+  if (!resultType.getAffineMaps().empty()) {
+    int64_t resultOffset;
+    SmallVector<int64_t, 4> resultStrides;
+    if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
+      return failure();
+
+    // Match offset in result memref type and in static_offsets attribute.
+    int64_t expectedOffset =
+        extractFromI64ArrayAttr(op.static_offsets()).front();
+    if (resultOffset != expectedOffset)
+      return op.emitError("expected result type with offset = ")
+             << resultOffset << " instead of " << expectedOffset;
+
+    // Match strides in result memref type and in static_strides attribute.
+    for (auto &en : llvm::enumerate(llvm::zip(
+             resultStrides, extractFromI64ArrayAttr(op.static_strides())))) {
+      int64_t resultStride = std::get<0>(en.value());
+      int64_t expectedStride = std::get<1>(en.value());
+      if (resultStride != expectedStride)
+        return op.emitError("expected result type with stride = ")
+               << expectedStride << " instead of " << resultStride
+               << " in dim = " << en.index();
+    }
   }
   return success();
 }
diff --git a/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt b/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_mlir_dialect_library(MLIRStandardOpsTransforms
   Bufferize.cpp
   ExpandAtomic.cpp
+  ExpandMemRefReshape.cpp
   ExpandTanh.cpp
   FuncConversions.cpp
diff --git a/mlir/lib/Dialect/StandardOps/Transforms/ExpandMemRefReshape.cpp b/mlir/lib/Dialect/StandardOps/Transforms/ExpandMemRefReshape.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/StandardOps/Transforms/ExpandMemRefReshape.cpp
@@ -0,0 +1,70 @@
+//===- ExpandMemRefReshape.cpp - Code to perform expanding memref_reshape -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements expansion of MemRefReshapeOp into
+// MemRefReinterpretCastOp.
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+
+namespace {
+
+/// Converts `memref_reshape` that has a target shape of a statically-known
+/// size to `memref_reinterpret_cast`.
+struct MemRefReshapeOpConverter : public OpRewritePattern<MemRefReshapeOp> {
+public:
+  using OpRewritePattern<MemRefReshapeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(MemRefReshapeOp op,
+                                PatternRewriter &rewriter) const final {
+    auto shapeType = op.shape().getType().cast<MemRefType>();
+    if (!shapeType.hasStaticShape())
+      return failure();
+
+    int64_t rank = shapeType.cast<MemRefType>().getDimSize(0);
+    SmallVector<Value, 4> sizes, strides;
+    sizes.resize(rank);
+    strides.resize(rank);
+
+    Location loc = op.getLoc();
+    Value stride = rewriter.create<ConstantIndexOp>(loc, 1);
+    for (int i = rank - 1; i >= 0; --i) {
+      Value index = rewriter.create<ConstantIndexOp>(loc, i);
+      Value size = rewriter.create<LoadOp>(loc, op.shape(), index);
+      if (!size.getType().isa<IndexType>())
+        size = rewriter.create<IndexCastOp>(loc, size, rewriter.getIndexType());
+      sizes[i] = size;
+      strides[i] = stride;
+      if (i > 0)
+        stride = rewriter.create<MulIOp>(loc, stride, size);
+    }
+    SmallVector<int64_t, 4> staticSizes(rank, ShapedType::kDynamicSize);
+    SmallVector<int64_t, 4> staticStrides(rank,
+                                          ShapedType::kDynamicStrideOrOffset);
+    rewriter.replaceOpWithNewOp<MemRefReinterpretCastOp>(
+        op, op.getType(), op.source(), /*staticOffset = */ 0, staticSizes,
+        staticStrides, /*offset=*/llvm::None, sizes, strides);
+    return success();
+  }
+};
+
+} // namespace
+
+void mlir::populateExpandMemRefReshapePattern(
+    OwningRewritePatternList &patterns, MLIRContext *ctx) {
+  patterns.insert<MemRefReshapeOpConverter>(ctx);
+}
diff --git a/mlir/test/Dialect/Standard/expand-memref-reshape.mlir b/mlir/test/Dialect/Standard/expand-memref-reshape.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Standard/expand-memref-reshape.mlir
@@ -0,0 +1,26 @@
+// RUN: mlir-opt %s -test-expand-memref-reshape | FileCheck %s
+
+// CHECK-LABEL: func @memref_reshape(
+func @memref_reshape(%input: memref<*xf32>,
+                     %shape: memref<3xi32>) -> memref<?x?x?xf32> {
+  %result = memref_reshape %input(%shape)
+               : (memref<*xf32>, memref<3xi32>) -> memref<?x?x?xf32>
+  return %result : memref<?x?x?xf32>
+}
+// CHECK-SAME: [[SRC:%.*]]: memref<*xf32>,
+// CHECK-SAME: [[SHAPE:%.*]]: memref<3xi32>) -> memref<?x?x?xf32> {
+// CHECK: [[C2:%.*]] = constant 2 : index
+// CHECK: [[C1:%.*]] = constant 1 : index
+// CHECK: [[C0:%.*]] = constant 0 : index
+// CHECK: [[DIM_2:%.*]] = load [[SHAPE]]{{\[}}[[C2]]] : memref<3xi32>
+// CHECK: [[SIZE_2:%.*]] = index_cast [[DIM_2]] : i32 to index
+// CHECK: [[DIM_1:%.*]] = load [[SHAPE]]{{\[}}[[C1]]] : memref<3xi32>
+// CHECK: [[SIZE_1:%.*]] = index_cast [[DIM_1]] : i32 to index
+// CHECK: [[STRIDE_0:%.*]] = muli [[SIZE_2]], [[SIZE_1]] : index
+// CHECK: [[DIM_0:%.*]] = load [[SHAPE]]{{\[}}[[C0]]] : memref<3xi32>
+// CHECK: [[SIZE_0:%.*]] = index_cast [[DIM_0]] : i32 to index
+
+// CHECK: [[RESULT:%.*]] = memref_reinterpret_cast [[SRC]]
+// CHECK-SAME: to offset: [0], sizes: {{\[}}[[SIZE_0]], [[SIZE_1]], [[SIZE_2]]],
+// CHECK-SAME: strides: {{\[}}[[STRIDE_0]], [[SIZE_2]], [[C1]]]
+// CHECK-SAME: : memref<*xf32> to memref<?x?x?xf32>
diff --git a/mlir/test/Dialect/Standard/invalid.mlir b/mlir/test/Dialect/Standard/invalid.mlir
--- a/mlir/test/Dialect/Standard/invalid.mlir
+++ b/mlir/test/Dialect/Standard/invalid.mlir
@@ -140,10 +140,10 @@
 // CHECK-LABEL: func @memref_reinterpret_cast_offset_mismatch
 func @memref_reinterpret_cast_offset_mismatch(%in: memref<?xf32>) {
-  // expected-error @+1 {{expected result type with offset = 0 instead of 1}}
+  // expected-error @+1 {{expected result type with offset = 2 instead of 1}}
   %out = memref_reinterpret_cast %in to
            offset: [1], sizes: [10], strides: [1]
-         : memref<?xf32> to memref<10xf32>
+         : memref<?xf32> to memref<10xf32, offset: 2, strides: [1]>
   return
 }
 
@@ -164,8 +164,8 @@
 func @memref_reinterpret_cast_offset_mismatch(%in: memref<?xf32>) {
   // expected-error @+1 {{expected result type with stride = 2 instead of 1 in dim = 0}}
   %out = memref_reinterpret_cast %in to
-           offset: [0], sizes: [10], strides: [2]
-         : memref<?xf32> to memref<10xf32>
+           offset: [2], sizes: [10], strides: [2]
+         : memref<?xf32> to memref<10xf32, offset: 2, strides: [1]>
   return
 }
 
diff --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt
--- a/mlir/test/lib/Transforms/CMakeLists.txt
+++ b/mlir/test/lib/Transforms/CMakeLists.txt
@@ -2,6 +2,7 @@
 add_mlir_library(MLIRTestTransforms
   TestAffineLoopParametricTiling.cpp
   TestBufferPlacement.cpp
+  TestExpandMemRefReshape.cpp
  TestExpandTanh.cpp
   TestCallGraph.cpp
   TestConstantFold.cpp
diff --git a/mlir/test/lib/Transforms/TestExpandMemRefReshape.cpp b/mlir/test/lib/Transforms/TestExpandMemRefReshape.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/test/lib/Transforms/TestExpandMemRefReshape.cpp
@@ -0,0 +1,37 @@
+//===- TestExpandMemRefReshape.cpp - Test expansion of memref_reshape -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains test passes for expanding memref reshape.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+
+namespace {
+struct TestExpandMemRefReshapePass
+    : public PassWrapper<TestExpandMemRefReshapePass, FunctionPass> {
+  void runOnFunction() override;
+};
+} // end anonymous namespace
+
+void TestExpandMemRefReshapePass::runOnFunction() {
+  OwningRewritePatternList patterns;
+  populateExpandMemRefReshapePattern(patterns, &getContext());
+  applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+}
+
+namespace mlir {
+void registerTestExpandMemRefReshapePass() {
+  PassRegistration<TestExpandMemRefReshapePass> pass(
+      "test-expand-memref-reshape", "Test expanding memref reshape");
+}
+} // namespace mlir
diff --git a/mlir/test/mlir-cpu-runner/memref_reshape.mlir b/mlir/test/mlir-cpu-runner/memref_reshape.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/mlir-cpu-runner/memref_reshape.mlir
@@ -0,0 +1,72 @@
+// RUN: mlir-opt %s -convert-scf-to-std -convert-std-to-llvm --print-ir-after-all \
+// RUN:   | mlir-cpu-runner -e main -entry-point-result=void \
+// RUN:   -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext,%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \
+// RUN:   | FileCheck %s
+
+
+func @print_memref_f32(memref<*xf32>) attributes { llvm.emit_c_interface }
+
+func @main() -> () {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+
+  // Initialize input.
+  %input = alloc() : memref<2x3xf32>
+  %dim_x = dim %input, %c0 : memref<2x3xf32>
+  %dim_y = dim %input, %c1 : memref<2x3xf32>
+  scf.parallel (%i, %j) = (%c0, %c0) to (%dim_x, %dim_y) step (%c1, %c1) {
+    %prod = muli %i, %dim_y : index
+    %val = addi %prod, %j : index
+    %val_i64 = index_cast %val : index to i64
+    %val_f32 = sitofp %val_i64 : i64 to f32
+    store %val_f32, %input[%i, %j] : memref<2x3xf32>
+  }
+  %unranked_input = memref_cast %input : memref<2x3xf32> to memref<*xf32>
+  call @print_memref_f32(%unranked_input) : (memref<*xf32>) -> ()
+  // CHECK: rank = 2 offset = 0 sizes = [2, 3] strides = [3, 1]
+  // CHECK-NEXT: [0, 1, 2]
+  // CHECK-NEXT: [3, 4, 5]
+
+  // Initialize shape.
+  %shape = alloc() : memref<2xindex>
+  %c2 = constant 2 : index
+  %c3 = constant 3 : index
+  store %c3, %shape[%c0] : memref<2xindex>
+  store %c2, %shape[%c1] : memref<2xindex>
+
+  // Test cases.
+  call @reshape_ranked_memref_to_ranked(%input, %shape)
+    : (memref<2x3xf32>, memref<2xindex>) -> ()
+  call @reshape_unranked_memref_to_ranked(%input, %shape)
+    : (memref<2x3xf32>, memref<2xindex>) -> ()
+  return
+}
+
+func @reshape_ranked_memref_to_ranked(%input : memref<2x3xf32>,
+                                      %shape : memref<2xindex>) {
+  %output = memref_reshape %input(%shape)
+                : (memref<2x3xf32>, memref<2xindex>) -> memref<?x?xf32>
+
+  %unranked_output = memref_cast %output : memref<?x?xf32> to memref<*xf32>
+  call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
+  // CHECK: rank = 2 offset = 0 sizes = [3, 2] strides = [2, 1] data =
+  // CHECK: [0, 1],
+  // CHECK: [2, 3],
+  // CHECK: [4, 5]
+  return
+}
+
+func @reshape_unranked_memref_to_ranked(%input : memref<2x3xf32>,
+                                        %shape : memref<2xindex>) {
+  %unranked_input = memref_cast %input : memref<2x3xf32> to memref<*xf32>
+  %output = memref_reshape %input(%shape)
+                : (memref<2x3xf32>, memref<2xindex>) -> memref<?x?xf32>
+
+  %unranked_output = memref_cast %output : memref<?x?xf32> to memref<*xf32>
+  call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
+  // CHECK: rank = 2 offset = 0 sizes = [3, 2] strides = [2, 1] data =
+  // CHECK: [0, 1],
+  // CHECK: [2, 3],
+  // CHECK: [4, 5]
+  return
+}
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -53,6 +53,7 @@
 void registerTestDominancePass();
 void registerTestDialect(DialectRegistry &);
 void registerTestDynamicPipelinePass();
+void registerTestExpandMemRefReshapePass();
 void registerTestExpandTanhPass();
 void registerTestFunc();
 void registerTestGpuMemoryPromotionPass();
@@ -115,6 +116,7 @@
   registerTestDynamicPipelinePass();
   registerTestFunc();
   registerTestExpandTanhPass();
+  registerTestExpandMemRefReshapePass();
   registerTestGpuMemoryPromotionPass();
   registerTestInterfaces();
   registerTestLinalgCodegenStrategy();