diff --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/ComposeSubView.h b/mlir/include/mlir/Dialect/StandardOps/Transforms/ComposeSubView.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/ComposeSubView.h
@@ -0,0 +1,32 @@
+//===- ComposeSubView.h - Combining composed subview ops --------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Patterns for combining composed subview ops.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_STANDARDOPS_TRANSFORMS_COMPOSESUBVIEW_H_
+#define MLIR_DIALECT_STANDARDOPS_TRANSFORMS_COMPOSESUBVIEW_H_
+
+namespace mlir {
+
+// Forward declarations.
+class MLIRContext;
+class RewritePatternSet;
+using OwningRewritePatternList = RewritePatternSet;
+
+/// Adds a pattern that folds a `memref.subview` whose source is another
+/// `memref.subview` into a single subview of the original source. It only
+/// applies when both subviews have strides of 1, the consumer has static
+/// sizes, and the producer subview is not rank-reducing.
+void populateComposeSubViewPatterns(OwningRewritePatternList &patterns,
+                                    MLIRContext *context);
+
+} // namespace mlir
+
+#endif // MLIR_DIALECT_STANDARDOPS_TRANSFORMS_COMPOSESUBVIEW_H_
diff --git a/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt b/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_dialect_library(MLIRStandardOpsTransforms
   Bufferize.cpp
+  ComposeSubView.cpp
   DecomposeCallGraphTypes.cpp
   ExpandOps.cpp
   FuncBufferize.cpp
diff --git a/mlir/lib/Dialect/StandardOps/Transforms/ComposeSubView.cpp b/mlir/lib/Dialect/StandardOps/Transforms/ComposeSubView.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/StandardOps/Transforms/ComposeSubView.cpp
@@ -0,0 +1,142 @@
+//===- ComposeSubView.cpp - Combining composed subview ops ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains patterns for combining composed subview ops (i.e. subview
+// of a subview becomes a single subview).
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/StandardOps/Transforms/ComposeSubView.h"
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+
+namespace {
+
+// Replaces a subview of a subview with a single subview. Only supports subview
+// ops with static sizes and static strides of 1 (both static and dynamic
+// offsets are supported). The strides of *both* subviews must be 1.
+struct ComposeSubViewOpPattern : public OpRewritePattern<memref::SubViewOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(memref::SubViewOp op,
+                                PatternRewriter &rewriter) const override {
+    // 'op' is the 'SubViewOp' we're rewriting. 'sourceOp' is the op that
+    // produces the input of the op we're rewriting (for 'SubViewOp' the input
+    // is called the "source" value). We can only combine them if both 'op' and
+    // 'sourceOp' are 'SubViewOp'.
+    auto sourceOp = op.source().getDefiningOp<memref::SubViewOp>();
+    if (!sourceOp)
+      return failure();
+
+    // A 'SubViewOp' can be "rank-reducing" by eliminating dimensions of the
+    // output memref that are statically known to be equal to 1. We do not
+    // allow 'sourceOp' to be a rank-reducing subview because then our two
+    // 'SubViewOp's would have different numbers of offset/size/stride
+    // parameters (just difficult to deal with, not impossible if we end up
+    // needing it).
+    if (sourceOp.getSourceType().getRank() != sourceOp.getType().getRank())
+      return failure();
+
+    // Every stride of *both* subviews must be the static constant 1: a
+    // non-unit stride in 'sourceOp' would scale the offsets of 'op', and a
+    // non-unit stride in 'op' would have to be carried into the combined op.
+    // Neither case is modeled below.
+    auto isStaticOne = [](OpFoldResult valueOrAttr) {
+      Attribute attr = valueOrAttr.dyn_cast<Attribute>();
+      return attr && attr.cast<IntegerAttr>().getInt() == 1;
+    };
+    if (!llvm::all_of(op.getMixedStrides(), isStaticOne) ||
+        !llvm::all_of(sourceOp.getMixedStrides(), isStaticOne))
+      return failure();
+
+    // We only support static sizes. This is checked before any IR is created
+    // because a pattern must not modify the IR and then report failure.
+    auto opSizes = op.getMixedSizes();
+    if (llvm::any_of(opSizes,
+                     [](OpFoldResult size) { return size.is<Value>(); }))
+      return failure();
+
+    // Offsets, sizes and strides OpFoldResult for the combined 'SubViewOp'.
+    // Because all input strides are 1, the output strides are 1 as well.
+    SmallVector<OpFoldResult> offsets, sizes;
+    SmallVector<OpFoldResult> strides(sourceOp.getMixedStrides().size(),
+                                      rewriter.getI64IntegerAttr(1));
+
+    // The rules for calculating the new offsets and sizes are:
+    // * Multiple subview offsets for a given dimension compose additively.
+    //   ("Offset by m" followed by "Offset by n" == "Offset by m + n")
+    // * Multiple sizes for a given dimension compose by taking the size of the
+    //   final subview and ignoring the rest. ("Take m values" followed by "Take
+    //   n values" == "Take n values") This size must also be the smallest one
+    //   by definition (a subview needs to be the same size as or smaller than
+    //   its source along each dimension; presumably subviews that are larger
+    //   than their sources are disallowed by validation).
+    for (auto it : llvm::zip(op.getMixedOffsets(), sourceOp.getMixedOffsets(),
+                             opSizes)) {
+      auto opOffset = std::get<0>(it);
+      auto sourceOffset = std::get<1>(it);
+      sizes.push_back(std::get<2>(it));
+
+      Attribute opOffsetAttr = opOffset.dyn_cast<Attribute>(),
+                sourceOffsetAttr = sourceOffset.dyn_cast<Attribute>();
+
+      if (opOffsetAttr && sourceOffsetAttr) {
+        // If both offsets are static we can simply calculate the combined
+        // offset statically.
+        offsets.push_back(rewriter.getI64IntegerAttr(
+            opOffsetAttr.cast<IntegerAttr>().getInt() +
+            sourceOffsetAttr.cast<IntegerAttr>().getInt()));
+      } else {
+        // When either offset is dynamic, we must emit an additional affine
+        // transformation to add the two offsets together dynamically.
+        AffineExpr expr = rewriter.getAffineConstantExpr(0);
+        SmallVector<Value> affineApplyOperands;
+        for (auto valueOrAttr : {opOffset, sourceOffset}) {
+          if (auto attr = valueOrAttr.dyn_cast<Attribute>()) {
+            expr = expr + attr.cast<IntegerAttr>().getInt();
+          } else {
+            expr =
+                expr + rewriter.getAffineSymbolExpr(affineApplyOperands.size());
+            affineApplyOperands.push_back(valueOrAttr.get<Value>());
+          }
+        }
+
+        AffineMap map = AffineMap::get(0, affineApplyOperands.size(), expr);
+        Value result = rewriter.create<AffineApplyOp>(op.getLoc(), map,
+                                                      affineApplyOperands);
+        offsets.push_back(result);
+      }
+    }
+
+    // This replaces 'op' but leaves 'sourceOp' alone; if it no longer has any
+    // uses it can be removed by a (separate) dead code elimination pass. The
+    // result type of 'op' is reused so users keep the exact type they expect
+    // (an inferred type would not be rank-reduced if 'op' is rank-reducing).
+    rewriter.replaceOpWithNewOp<memref::SubViewOp>(
+        op, op.getType().cast<MemRefType>(), sourceOp.source(), offsets, sizes,
+        strides);
+    return success();
+  }
+};
+
+} // namespace
+
+void populateComposeSubViewPatterns(OwningRewritePatternList &patterns,
+                                    MLIRContext *context) {
+  patterns.insert<ComposeSubViewOpPattern>(context);
+}
+
+} // namespace mlir
diff --git a/mlir/test/Transforms/compose-subview.mlir b/mlir/test/Transforms/compose-subview.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Transforms/compose-subview.mlir
@@ -0,0 +1,75 @@
+// RUN: mlir-opt -allow-unregistered-dialect %s -test-compose-subview -split-input-file | FileCheck %s
+
+// CHECK: [[MAP:#.*]] = affine_map<(d0, d1) -> (d0 * 1024 + d1 + 3456)
+#map0 = affine_map<(d0, d1) -> (d0 * 1024 + d1 + 2304)>
+#map1 = affine_map<(d0, d1) -> (d0 * 1024 + d1 + 3456)>
+
+func @main(%input: memref<4x1024xf32>) -> memref<1x128xf32, #map1> {
+  // CHECK: subview %arg0[3, 384] [1, 128] [1, 1]
+  // CHECK-SAME: memref<4x1024xf32> to memref<1x128xf32, [[MAP]]>
+  %0 = memref.subview %input[2, 256] [2, 256] [1, 1] : memref<4x1024xf32> to memref<2x256xf32, #map0>
+  %1 = memref.subview %0[1, 128] [1, 128] [1, 1] : memref<2x256xf32, #map0> to memref<1x128xf32, #map1>
+  return %1 : memref<1x128xf32, #map1>
+}
+
+// -----
+
+// CHECK: [[MAP:#.*]] = affine_map<(d0, d1) -> (d0 * 1024 + d1 + 3745)
+#map0 = affine_map<(d0, d1) -> (d0 * 1024 + d1 + 1536)>
+#map1 = affine_map<(d0, d1) -> (d0 * 1024 + d1 + 2688)>
+#map2 = affine_map<(d0, d1) -> (d0 * 1024 + d1 + 3745)>
+
+func @main(%input: memref<4x1024xf32>) -> memref<1x10xf32, #map2> {
+  // CHECK: subview %arg0[3, 673] [1, 10] [1, 1]
+  // CHECK-SAME: memref<4x1024xf32> to memref<1x10xf32, [[MAP]]>
+  %0 = memref.subview %input[1, 512] [3, 256] [1, 1] : memref<4x1024xf32> to memref<3x256xf32, #map0>
+  %1 = memref.subview %0[1, 128] [2, 128] [1, 1] : memref<3x256xf32, #map0> to memref<2x128xf32, #map1>
+  %2 = memref.subview %1[1, 33] [1, 10] [1, 1] : memref<2x128xf32, #map1> to memref<1x10xf32, #map2>
+  return %2 : memref<1x10xf32, #map2>
+}
+
+// -----
+
+// Subviews are not composed when the source subview has a non-unit stride.
+#map = affine_map<(d0, d1) -> (d0 * 1024 + d1 * 2)>
+
+func @main(%input: memref<4x1024xf32>) -> memref<1x128xf32, #map> {
+  // CHECK: [[SV:%.*]] = memref.subview %arg0[0, 0] [2, 512] [1, 2]
+  // CHECK: memref.subview [[SV]][0, 0] [1, 128] [1, 1]
+  %0 = memref.subview %input[0, 0] [2, 512] [1, 2] : memref<4x1024xf32> to memref<2x512xf32, #map>
+  %1 = memref.subview %0[0, 0] [1, 128] [1, 1] : memref<2x512xf32, #map> to memref<1x128xf32, #map>
+  return %1 : memref<1x128xf32, #map>
+}
+
+// -----
+
+// CHECK: [[MAP:#.*]] = affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)
+#map = affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>
+
+func @main(%input: memref<4x1024xf32>) -> memref<1x128xf32, #map> {
+  // CHECK: [[CST_3:%.*]] = constant 3 : index
+  %cst_1 = constant 1 : index
+  %cst_2 = constant 2 : index
+  // CHECK: subview %arg0{{\[}}[[CST_3]], 384] [1, 128] [1, 1]
+  // CHECK-SAME: memref<4x1024xf32> to memref<1x128xf32, [[MAP]]>
+  %0 = memref.subview %input[%cst_2, 256] [2, 256] [1, 1] : memref<4x1024xf32> to memref<2x256xf32, #map>
+  %1 = memref.subview %0[%cst_1, 128] [1, 128] [1, 1] : memref<2x256xf32, #map> to memref<1x128xf32, #map>
+  return %1 : memref<1x128xf32, #map>
+}
+
+// -----
+
+// CHECK: [[MAP:#.*]] = affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)
+#map = affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>
+
+func @main(%input: memref<4x1024xf32>) -> memref<1x128xf32, #map> {
+  // CHECK: [[CST_3:%.*]] = constant 3 : index
+  %cst_2 = constant 2 : index
+  // CHECK: [[CST_384:%.*]] = constant 384 : index
+  %cst_128 = constant 128 : index
+  // CHECK: subview %arg0{{\[}}[[CST_3]], [[CST_384]]] [1, 128] [1, 1]
+  // CHECK-SAME: memref<4x1024xf32> to memref<1x128xf32, [[MAP]]>
+  %0 = memref.subview %input[%cst_2, 256] [2, 256] [1, 1] : memref<4x1024xf32> to memref<2x256xf32, #map>
+  %1 = memref.subview %0[1, %cst_128] [1, 128] [1, 1] : memref<2x256xf32, #map> to memref<1x128xf32, #map>
+  return %1 : memref<1x128xf32, #map>
+}
diff --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt
--- a/mlir/test/lib/Transforms/CMakeLists.txt
+++ b/mlir/test/lib/Transforms/CMakeLists.txt
@@ -12,6 +12,7 @@
   TestDataLayoutQuery.cpp
   TestDominance.cpp
   TestDynamicPipeline.cpp
+  TestComposeSubView.cpp
   TestLoopFusion.cpp
   TestGpuMemoryPromotion.cpp
   TestGpuParallelLoopMapping.cpp
diff --git a/mlir/test/lib/Transforms/TestComposeSubView.cpp b/mlir/test/lib/Transforms/TestComposeSubView.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/test/lib/Transforms/TestComposeSubView.cpp
@@ -0,0 +1,49 @@
+//===- TestComposeSubView.cpp - Test composed subviews --------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass to test the composed subview patterns.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/StandardOps/Transforms/ComposeSubView.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+
+namespace {
+/// Function pass that greedily applies the patterns registered by
+/// populateComposeSubViewPatterns to the function it runs on.
+struct TestComposeSubViewPass
+    : public PassWrapper<TestComposeSubViewPass, FunctionPass> {
+  /// The composition pattern can emit `affine.apply` ops for dynamic offsets,
+  /// so the affine dialect is declared as a dependent dialect.
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<AffineDialect>();
+  }
+
+  void runOnFunction() override {
+    MLIRContext *ctx = &getContext();
+    OwningRewritePatternList composePatterns(ctx);
+    populateComposeSubViewPatterns(composePatterns, ctx);
+    // The driver's convergence result is deliberately ignored.
+    (void)applyPatternsAndFoldGreedily(getOperation(),
+                                       std::move(composePatterns));
+  }
+};
+} // namespace
+
+namespace mlir {
+namespace test {
+void registerTestComposeSubView() {
+  PassRegistration<TestComposeSubViewPass> pass(
+      "test-compose-subview", "Test combining composed subviews");
+}
+} // namespace test
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -72,6 +72,7 @@
 void registerTestDominancePass();
 void registerTestDynamicPipelinePass();
 void registerTestExpandTanhPass();
+void registerTestComposeSubView();
 void registerTestGpuParallelLoopMappingPass();
 void registerTestIRVisitorsPass();
 void registerTestInterfaces();
@@ -148,6 +149,7 @@
   test::registerTestDominancePass();
   test::registerTestDynamicPipelinePass();
   test::registerTestExpandTanhPass();
+  test::registerTestComposeSubView();
   test::registerTestGpuParallelLoopMappingPass();
   test::registerTestIRVisitorsPass();
   test::registerTestInterfaces();