diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -2256,6 +2256,20 @@
     I64ArrayAttr:$static_strides
   );
   let results = (outs AnyMemRef:$result);
+
+  let builders = [
+    // Build a MemRefReinterpretCastOp with mixed static and dynamic entries.
+    OpBuilder<
+      "MemRefType resultType, Value source, int64_t staticOffset, "
+      "ArrayRef<int64_t> staticSizes, ArrayRef<int64_t> staticStrides, "
+      "ValueRange offset, ValueRange sizes, ValueRange strides, "
+      "ArrayRef<NamedAttribute> attrs = {}">,
+    // Build a MemRefReinterpretCastOp with all dynamic entries.
+    OpBuilder<
+      "MemRefType resultType, Value source, Value offset, ValueRange sizes, "
+      "ValueRange strides, ArrayRef<NamedAttribute> attrs = {}">,
+  ];
+
   let extraClassDeclaration = extraBaseClassDeclaration # [{
     // The result of the op is always a ranked memref.
     MemRefType getType() { return getResult().getType().cast<MemRefType>(); }
diff --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h
--- a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h
@@ -25,6 +25,9 @@
 /// Creates an instance of the ExpandAtomic pass.
std::unique_ptr<Pass> createExpandAtomicPass();
 
+void populateExpandMemRefReshapePattern(OwningRewritePatternList &patterns,
+                                        MLIRContext *ctx);
+
 void populateExpandTanhPattern(OwningRewritePatternList &patterns,
                                MLIRContext *ctx);
 
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -2269,6 +2269,35 @@
 // MemRefReinterpretCastOp
 //===----------------------------------------------------------------------===//
 
+void mlir::MemRefReinterpretCastOp::build(
+    OpBuilder &b, OperationState &result, MemRefType resultType, Value source,
+    int64_t staticOffset, ArrayRef<int64_t> staticSizes,
+    ArrayRef<int64_t> staticStrides, ValueRange offset, ValueRange sizes,
+    ValueRange strides, ArrayRef<NamedAttribute> attrs) {
+  build(b, result, resultType, source, offset, sizes, strides,
+        b.getI64ArrayAttr(staticOffset), b.getI64ArrayAttr(staticSizes),
+        b.getI64ArrayAttr(staticStrides));
+  result.addAttributes(attrs);
+}
+
+/// Build a MemRefReinterpretCastOp with all dynamic entries: `staticOffset`,
+/// `staticSizes` and `staticStrides` are automatically filled with
+/// result-memref-rank sentinel values that encode dynamic entries.
+void mlir::MemRefReinterpretCastOp::build(OpBuilder &b, OperationState &result,
+                                          MemRefType resultType, Value source,
+                                          Value offset, ValueRange sizes,
+                                          ValueRange strides,
+                                          ArrayRef<NamedAttribute> attrs) {
+  unsigned rank = resultType.getRank();
+  SmallVector<int64_t, 4> staticSizesVector;
+  staticSizesVector.assign(rank, ShapedType::kDynamicSize);
+  SmallVector<int64_t, 4> staticStridesVector;
+  staticStridesVector.assign(rank, ShapedType::kDynamicStrideOrOffset);
+  build(b, result, resultType, source,
+        /*staticOffset=*/ShapedType::kDynamicStrideOrOffset, staticSizesVector,
+        staticStridesVector, offset, sizes, strides, attrs);
+}
+
 /// Print of the form:
 /// ```
 /// `name` ssa-name to
@@ -2390,41 +2419,6 @@
                     op.static_strides(), ShapedType::isDynamicStrideOrOffset,
                     op.strides())))
     return failure();
-
-  // Extract source offset and strides.
-  int64_t resultOffset;
-  SmallVector<int64_t, 4> resultStrides;
-  if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
-    return failure();
-
-  // Match offset in result memref type and in static_offsets attribute.
-  int64_t expectedOffset = extractFromI64ArrayAttr(op.static_offsets()).front();
-  if (resultOffset != expectedOffset)
-    return op.emitError("expected result type with offset = ")
-           << resultOffset << " instead of " << expectedOffset;
-
-  // Match sizes in result memref type and in static_sizes attribute.
-  for (auto &en :
-       llvm::enumerate(llvm::zip(resultType.getShape(),
-                                 extractFromI64ArrayAttr(op.static_sizes())))) {
-    int64_t resultSize = std::get<0>(en.value());
-    int64_t expectedSize = std::get<1>(en.value());
-    if (resultSize != expectedSize)
-      return op.emitError("expected result type with size = ")
-             << expectedSize << " instead of " << resultSize
-             << " in dim = " << en.index();
-  }
-
-  // Match strides in result memref type and in static_strides attribute.
-  for (auto &en : llvm::enumerate(llvm::zip(
-           resultStrides, extractFromI64ArrayAttr(op.static_strides())))) {
-    int64_t resultStride = std::get<0>(en.value());
-    int64_t expectedStride = std::get<1>(en.value());
-    if (resultStride != expectedStride)
-      return op.emitError("expected result type with stride = ")
-             << expectedStride << " instead of " << resultStride
-             << " in dim = " << en.index();
-  }
   return success();
 }
 
diff --git a/mlir/lib/Dialect/StandardOps/Transforms/ExpandMemRefReshape.cpp b/mlir/lib/Dialect/StandardOps/Transforms/ExpandMemRefReshape.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/StandardOps/Transforms/ExpandMemRefReshape.cpp
@@ -0,0 +1,70 @@
+//===- ExpandMemRefReshape.cpp - Code to perform expanding memref_reshape -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements expansion of MemRefReshapeOp into
+// MemRefReinterpretCastOp.
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+
+namespace {
+
+/// Converts `memref_reshape` whose shape operand has a statically-known size
+/// to `memref_reinterpret_cast`: sizes are loaded from the shape operand and
+/// strides are running products of the inner sizes (identity layout).
+struct MemRefReshapeOpConverter : public OpRewritePattern<MemRefReshapeOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(MemRefReshapeOp op,
+                                PatternRewriter &rewriter) const final {
+    if (!op.shape().getType().cast<MemRefType>().hasStaticShape())
+      return failure();
+
+    int64_t rank = op.shape().getType().cast<MemRefType>().getDimSize(0);
+    SmallVector<Value, 4> sizes, strides;
+    sizes.resize(rank);
+    strides.resize(rank);
+
+    Location loc = op.getLoc();
+    Value stride = rewriter.create<ConstantIndexOp>(loc, 1);
+    for (int64_t i = rank - 1; i >= 0; --i) {
+      Value index = rewriter.create<ConstantIndexOp>(loc, i);
+      Value size = rewriter.create<LoadOp>(loc, op.shape(), index);
+      if (!size.getType().isa<IndexType>())
+        size = rewriter.create<IndexCastOp>(loc, size, rewriter.getIndexType());
+      sizes[i] = size;
+      strides[i] = stride;
+      if (i > 0)
+        stride = rewriter.create<MulIOp>(loc, stride, size);
+    }
+    SmallVector<int64_t, 2> staticSizes(rank, ShapedType::kDynamicSize);
+    SmallVector<int64_t, 2> staticStrides(rank,
+                                          ShapedType::kDynamicStrideOrOffset);
+    rewriter.replaceOpWithNewOp<MemRefReinterpretCastOp>(
+        op, op.getType(), op.source(), /*staticOffset=*/0, staticSizes,
+        staticStrides, /*offset=*/llvm::None, sizes, strides);
+    return success();
+  }
+};
+
+} // namespace
+
+void mlir::populateExpandMemRefReshapePattern(
+    OwningRewritePatternList &patterns, MLIRContext *ctx) {
+  patterns.insert<MemRefReshapeOpConverter>(ctx);
+}
diff --git a/mlir/test/Dialect/Standard/expand-memref-reshape.mlir b/mlir/test/Dialect/Standard/expand-memref-reshape.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Standard/expand-memref-reshape.mlir
@@ -0,0 +1,26 @@
+// RUN: mlir-opt %s -test-expand-memref-reshape | FileCheck %s
+
+// CHECK-LABEL: func @memref_reshape(
+func @memref_reshape(%input: memref<*xf32>,
+                     %shape: memref<3xi32>) -> memref<?x?x?xf32> {
+  %result = memref_reshape %input(%shape)
+             : (memref<*xf32>, memref<3xi32>) -> memref<?x?x?xf32>
+  return %result : memref<?x?x?xf32>
+}
+// CHECK-SAME: [[SRC:%.*]]: memref<*xf32>,
+// CHECK-SAME: [[SHAPE:%.*]]: memref<3xi32>) -> memref<?x?x?xf32> {
+// CHECK: [[C2:%.*]] = constant 2 : index
+// CHECK: [[C1:%.*]] = constant 1 : index
+// CHECK: [[C0:%.*]] = constant 0 : index
+// CHECK: [[DIM_2:%.*]] = load [[SHAPE]]{{\[}}[[C2]]] : memref<3xi32>
+// CHECK: [[SIZE_2:%.*]] = index_cast [[DIM_2]] : i32 to index
+// CHECK: [[DIM_1:%.*]] = load [[SHAPE]]{{\[}}[[C1]]] : memref<3xi32>
+// CHECK: [[SIZE_1:%.*]] = index_cast [[DIM_1]] : i32 to index
+// CHECK: [[STRIDE_0:%.*]] = muli [[SIZE_2]], [[SIZE_1]] : index
+// CHECK: [[DIM_0:%.*]] = load [[SHAPE]]{{\[}}[[C0]]] : memref<3xi32>
+// CHECK: [[SIZE_0:%.*]] = index_cast [[DIM_0]] : i32 to index
+// CHECK: [[RESULT:%.*]] = memref_reinterpret_cast [[SRC]]
+
+// CHECK-SAME: to offset: [0], sizes: {{\[}}[[SIZE_0]], [[SIZE_1]], [[SIZE_2]]],
+// CHECK-SAME: strides: {{\[}}[[STRIDE_0]], [[SIZE_2]], [[C1]]]
+// CHECK-SAME: : memref<*xf32> to memref<?x?x?xf32>
diff --git a/mlir/test/Dialect/Standard/invalid.mlir b/mlir/test/Dialect/Standard/invalid.mlir
--- a/mlir/test/Dialect/Standard/invalid.mlir
+++ b/mlir/test/Dialect/Standard/invalid.mlir
@@ -138,52 +138,6 @@
 
 // -----
 
-// CHECK-LABEL: func @memref_reinterpret_cast_offset_mismatch
-func @memref_reinterpret_cast_offset_mismatch(%in: memref<?xf32>) {
-  // expected-error @+1 {{expected result type with offset = 0 instead of 1}}
-  %out = memref_reinterpret_cast %in to
-           offset: [1], sizes: [10], strides: [1]
-           : memref<?xf32> to memref<10xf32>
-  return
-}
-
-// -----
-
-// CHECK-LABEL: func @memref_reinterpret_cast_size_mismatch
-func @memref_reinterpret_cast_size_mismatch(%in: memref<*xf32>) {
-  // expected-error @+1 {{expected result type with size = 10 instead of 1 in dim = 0}}
-  %out = memref_reinterpret_cast %in to
-           offset: [0], sizes: [10], strides: [1]
-           : memref<*xf32> to memref<1xf32, offset: 0, strides: [1]>
-  return
-}
-
-// -----
-
-// CHECK-LABEL: func @memref_reinterpret_cast_stride_mismatch
-func @memref_reinterpret_cast_offset_mismatch(%in: memref<?xf32>) {
-  // expected-error @+1 {{expected result type with stride = 2 instead of 1 in dim = 0}}
-  %out = memref_reinterpret_cast %in to
-           offset: [0], sizes: [10], strides: [2]
-           : memref<?xf32> to memref<10xf32>
-  return
-}
-
-// -----
-
-// CHECK-LABEL: func @memref_reinterpret_cast_dynamic_size_mismatch
-func @memref_reinterpret_cast_offset_mismatch(%in: memref<?xf32>) {
-  %c0 = constant 0 : index
-  %c10 = constant 10 : index
-  // expected-error @+1 {{expected result type with size = 10 instead of -1 in dim = 0}}
-  %out = memref_reinterpret_cast %in to
-           offset: [%c0], sizes: [10, %c10], strides: [%c10, 1]
-           : memref<?xf32> to memref<?x?xf32, offset: ?, strides: [?, 1]>
-  return
-}
-
-// -----
-
 // CHECK-LABEL: memref_reshape_element_type_mismatch
 func @memref_reshape_element_type_mismatch(
     %buf: memref<*xf32>, %shape: memref<1xi32>) {
diff --git a/mlir/test/lib/Transforms/TestExpandMemRefReshape.cpp b/mlir/test/lib/Transforms/TestExpandMemRefReshape.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/test/lib/Transforms/TestExpandMemRefReshape.cpp
@@ -0,0 +1,37 @@
+//===- TestExpandMemRefReshape.cpp - Test expansion of memref_reshape -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains test passes for expanding memref reshape.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+
+namespace {
+struct TestExpandMemRefReshapePass
+    : public PassWrapper<TestExpandMemRefReshapePass, FunctionPass> {
+  void runOnFunction() override;
+};
+} // end anonymous namespace
+
+void TestExpandMemRefReshapePass::runOnFunction() {
+  OwningRewritePatternList patterns;
+  populateExpandMemRefReshapePattern(patterns, &getContext());
+  applyPatternsAndFoldGreedily(getOperation(), patterns);
+}
+
+namespace mlir {
+void registerTestExpandMemRefReshapePass() {
+  PassRegistration<TestExpandMemRefReshapePass> pass(
+      "test-expand-memref-reshape", "Test expanding memref reshape");
+}
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -53,6 +53,7 @@
 void registerTestDominancePass();
 void registerTestDialect(DialectRegistry &);
 void registerTestDynamicPipelinePass();
+void registerTestExpandMemRefReshapePass();
 void registerTestExpandTanhPass();
 void registerTestFunc();
 void registerTestGpuMemoryPromotionPass();
@@ -115,6 +116,7 @@
   registerTestDynamicPipelinePass();
   registerTestFunc();
   registerTestExpandTanhPass();
+  registerTestExpandMemRefReshapePass();
   registerTestGpuMemoryPromotionPass();
   registerTestInterfaces();
   registerTestLinalgCodegenStrategy();