diff --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/ComposeSubView.h b/mlir/include/mlir/Dialect/MemRef/Transforms/ComposeSubView.h rename from mlir/include/mlir/Dialect/StandardOps/Transforms/ComposeSubView.h rename to mlir/include/mlir/Dialect/MemRef/Transforms/ComposeSubView.h --- a/mlir/include/mlir/Dialect/StandardOps/Transforms/ComposeSubView.h +++ b/mlir/include/mlir/Dialect/MemRef/Transforms/ComposeSubView.h @@ -1,4 +1,4 @@ -//===- ComposeSubView.h - Combining composed subview ops --------*- C++ -*-===// +//===- ComposeSubView.h - Combining composed memref ops ---------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -10,19 +10,20 @@ // //===----------------------------------------------------------------------===// -#ifndef MLIR_DIALECT_STANDARDOPS_TRANSFORMS_COMPOSESUBVIEW_H_ -#define MLIR_DIALECT_STANDARDOPS_TRANSFORMS_COMPOSESUBVIEW_H_ +#ifndef MLIR_DIALECT_MEMREF_TRANSFORMS_COMPOSESUBVIEW_H_ +#define MLIR_DIALECT_MEMREF_TRANSFORMS_COMPOSESUBVIEW_H_ namespace mlir { - -// Forward declarations. class MLIRContext; class RewritePatternSet; using OwningRewritePatternList = RewritePatternSet; +namespace memref { + void populateComposeSubViewPatterns(OwningRewritePatternList &patterns, MLIRContext *context); +} // namespace memref } // namespace mlir -#endif // MLIR_DIALECT_STANDARDOPS_TRANSFORMS_COMPOSESUBVIEW_H_ +#endif // MLIR_DIALECT_MEMREF_TRANSFORMS_COMPOSESUBVIEW_H_ diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h --- a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h @@ -18,6 +18,7 @@ namespace mlir { class AffineDialect; +class StandardOpsDialect; namespace tensor { class TensorDialect; } // namespace tensor @@ -31,6 +32,9 @@ // Patterns //===----------------------------------------------------------------------===// +/// Collects a set of patterns to rewrite ops within the memref dialect. +void populateExpandOpsPatterns(RewritePatternSet &patterns); + /// Appends patterns for folding memref.subview ops into consumer load/store ops /// into `patterns`. void populateFoldSubViewOpPatterns(RewritePatternSet &patterns); @@ -51,6 +55,11 @@ // Passes //===----------------------------------------------------------------------===// +/// Creates an instance of the ExpandOps pass that legalizes memref dialect ops +/// to be convertible to LLVM. For example, `memref.reshape` gets converted to +/// `memref_reinterpret_cast`. +std::unique_ptr createExpandOpsPass(); + /// Creates an operation pass to fold memref.subview ops into consumer /// load/store ops into `patterns`. std::unique_ptr createFoldSubViewOpsPass(); diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td --- a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td @@ -11,6 +11,12 @@ include "mlir/Pass/PassBase.td" +def ExpandOps : Pass<"memref-expand", "FuncOp"> { + let summary = "Legalize memref operations to be convertible to LLVM."; + let constructor = "mlir::memref::createExpandOpsPass()"; + let dependentDialects = ["StandardOpsDialect"]; +} + def FoldSubViewOps : Pass<"fold-memref-subview-ops"> { let summary = "Fold memref.subview ops into consumer load/store ops"; let description = [{ diff --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h --- a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h @@ -45,16 +45,6 @@ /// Creates an instance of tensor constant bufferization pass. std::unique_ptr createTensorConstantBufferizePass(unsigned alignment = 0); -/// Creates an instance of the StdExpand pass that legalizes Std -/// dialect ops to be convertible to LLVM. For example, -/// `std.arith.ceildivsi` gets transformed to a number of std operations, -/// which can be lowered to LLVM; `memref.reshape` gets converted to -/// `memref_reinterpret_cast`. -std::unique_ptr createStdExpandOpsPass(); - -/// Collects a set of patterns to rewrite ops within the Std dialect. -void populateStdExpandOpsPatterns(RewritePatternSet &patterns); - //===----------------------------------------------------------------------===// // Registration //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td --- a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td @@ -18,11 +18,6 @@ "memref::MemRefDialect", "scf::SCFDialect"]; } -def StdExpandOps : Pass<"std-expand", "FuncOp"> { - let summary = "Legalize std operations to be convertible to LLVM."; - let constructor = "mlir::createStdExpandOpsPass()"; -} - def FuncBufferize : Pass<"func-bufferize", "ModuleOp"> { let summary = "Bufferize func/call/return ops"; let description = [{ diff --git a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt @@ -1,4 +1,6 @@ add_mlir_dialect_library(MLIRMemRefTransforms + ComposeSubView.cpp + ExpandOps.cpp FoldSubViewOps.cpp NormalizeMemRefs.cpp ResolveShapedTypeResultDims.cpp @@ -15,6 +17,7 @@ MLIRInferTypeOpInterface MLIRMemRef MLIRPass + MLIRStandard MLIRTensor MLIRTransforms MLIRVector diff --git a/mlir/lib/Dialect/StandardOps/Transforms/ComposeSubView.cpp b/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp rename from mlir/lib/Dialect/StandardOps/Transforms/ComposeSubView.cpp rename to mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp --- a/mlir/lib/Dialect/StandardOps/Transforms/ComposeSubView.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp @@ -11,8 +11,7 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Dialect/StandardOps/Transforms/ComposeSubView.h" - +#include "mlir/Dialect/MemRef/Transforms/ComposeSubView.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/IR/BuiltinAttributes.h" @@ -21,7 +20,7 @@ #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" -namespace mlir { +using namespace mlir; namespace { @@ -128,9 +127,7 @@ } // namespace -void populateComposeSubViewPatterns(OwningRewritePatternList &patterns, - MLIRContext *context) { +void mlir::memref::populateComposeSubViewPatterns( + OwningRewritePatternList &patterns, MLIRContext *context) { patterns.insert(context); } - -} // namespace mlir diff --git a/mlir/lib/Dialect/StandardOps/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp rename from mlir/lib/Dialect/StandardOps/Transforms/ExpandOps.cpp rename to mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp --- a/mlir/lib/Dialect/StandardOps/Transforms/ExpandOps.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp @@ -17,8 +17,8 @@ #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Arithmetic/Transforms/Passes.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/MemRef/Transforms/Passes.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" -#include "mlir/Dialect/StandardOps/Transforms/Passes.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Transforms/DialectConversion.h" @@ -120,13 +120,13 @@ } }; -struct StdExpandOpsPass : public StdExpandOpsBase { +struct ExpandOpsPass : public ExpandOpsBase { void runOnOperation() override { MLIRContext &ctx = getContext(); RewritePatternSet patterns(&ctx); - populateStdExpandOpsPatterns(patterns); - ConversionTarget target(getContext()); + memref::populateExpandOpsPatterns(patterns); + ConversionTarget target(ctx); target.addLegalDialect(); @@ -146,11 +146,11 @@ } // namespace -void mlir::populateStdExpandOpsPatterns(RewritePatternSet &patterns) { +void mlir::memref::populateExpandOpsPatterns(RewritePatternSet &patterns) { patterns.add( patterns.getContext()); } -std::unique_ptr mlir::createStdExpandOpsPass() { - return std::make_unique(); +std::unique_ptr mlir::memref::createExpandOpsPass() { + return std::make_unique(); } diff --git a/mlir/lib/Dialect/MemRef/Transforms/PassDetail.h b/mlir/lib/Dialect/MemRef/Transforms/PassDetail.h --- a/mlir/lib/Dialect/MemRef/Transforms/PassDetail.h +++ b/mlir/lib/Dialect/MemRef/Transforms/PassDetail.h @@ -14,6 +14,7 @@ namespace mlir { class AffineDialect; +class StandardOpsDialect; // Forward declaration from Dialect.h template diff --git a/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt b/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt @@ -1,8 +1,6 @@ add_mlir_dialect_library(MLIRStandardOpsTransforms Bufferize.cpp - ComposeSubView.cpp DecomposeCallGraphTypes.cpp - ExpandOps.cpp FuncBufferize.cpp FuncConversions.cpp TensorConstantBufferize.cpp diff --git a/mlir/test/Dialect/Standard/expand-ops.mlir b/mlir/test/Dialect/MemRef/expand-ops.mlir rename from mlir/test/Dialect/Standard/expand-ops.mlir rename to mlir/test/Dialect/MemRef/expand-ops.mlir --- a/mlir/test/Dialect/Standard/expand-ops.mlir +++ b/mlir/test/Dialect/MemRef/expand-ops.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt -std-expand %s -split-input-file | FileCheck %s +// RUN: mlir-opt -memref-expand %s -split-input-file | FileCheck %s // CHECK-LABEL: func @atomic_rmw_to_generic // CHECK-SAME: ([[F:%.*]]: memref<10xf32>, [[f:%.*]]: f32, [[i:%.*]]: index) diff --git a/mlir/test/lib/Dialect/CMakeLists.txt b/mlir/test/lib/Dialect/CMakeLists.txt --- a/mlir/test/lib/Dialect/CMakeLists.txt +++ b/mlir/test/lib/Dialect/CMakeLists.txt @@ -3,6 +3,7 @@ add_subdirectory(GPU) add_subdirectory(Linalg) add_subdirectory(Math) +add_subdirectory(MemRef) add_subdirectory(SCF) add_subdirectory(Shape) add_subdirectory(SPIRV) diff --git a/mlir/test/lib/Dialect/StandardOps/CMakeLists.txt b/mlir/test/lib/Dialect/MemRef/CMakeLists.txt copy from mlir/test/lib/Dialect/StandardOps/CMakeLists.txt copy to mlir/test/lib/Dialect/MemRef/CMakeLists.txt --- a/mlir/test/lib/Dialect/StandardOps/CMakeLists.txt +++ b/mlir/test/lib/Dialect/MemRef/CMakeLists.txt @@ -1,19 +1,16 @@ # Exclude tests from libMLIR.so -add_mlir_library(MLIRStandardOpsTestPasses - TestDecomposeCallGraphTypes.cpp +add_mlir_library(MLIRMemRefTestPasses TestComposeSubView.cpp EXCLUDE_FROM_LIBMLIR LINK_LIBS PUBLIC - MLIRAffine MLIRPass - MLIRStandardOpsTransforms + MLIRMemRefTransforms MLIRTestDialect - MLIRTransformUtils ) -target_include_directories(MLIRStandardOpsTestPasses +target_include_directories(MLIRMemRefTestPasses PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../Test ${CMAKE_CURRENT_BINARY_DIR}/../Test diff --git a/mlir/test/lib/Dialect/StandardOps/TestComposeSubView.cpp b/mlir/test/lib/Dialect/MemRef/TestComposeSubView.cpp rename from mlir/test/lib/Dialect/StandardOps/TestComposeSubView.cpp rename to mlir/test/lib/Dialect/MemRef/TestComposeSubView.cpp --- a/mlir/test/lib/Dialect/StandardOps/TestComposeSubView.cpp +++ b/mlir/test/lib/Dialect/MemRef/TestComposeSubView.cpp @@ -11,7 +11,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Affine/IR/AffineOps.h" -#include "mlir/Dialect/StandardOps/Transforms/ComposeSubView.h" +#include "mlir/Dialect/MemRef/Transforms/ComposeSubView.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -35,7 +35,7 @@ void TestComposeSubViewPass::runOnOperation() { OwningRewritePatternList patterns(&getContext()); - populateComposeSubViewPatterns(patterns, &getContext()); + memref::populateComposeSubViewPatterns(patterns, &getContext()); (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); } } // namespace diff --git a/mlir/test/lib/Dialect/StandardOps/CMakeLists.txt b/mlir/test/lib/Dialect/StandardOps/CMakeLists.txt --- a/mlir/test/lib/Dialect/StandardOps/CMakeLists.txt +++ b/mlir/test/lib/Dialect/StandardOps/CMakeLists.txt @@ -1,7 +1,6 @@ # Exclude tests from libMLIR.so add_mlir_library(MLIRStandardOpsTestPasses TestDecomposeCallGraphTypes.cpp - TestComposeSubView.cpp EXCLUDE_FROM_LIBMLIR diff --git a/mlir/test/mlir-cpu-runner/memref-reshape.mlir b/mlir/test/mlir-cpu-runner/memref-reshape.mlir --- a/mlir/test/mlir-cpu-runner/memref-reshape.mlir +++ b/mlir/test/mlir-cpu-runner/memref-reshape.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -convert-scf-to-std -std-expand -convert-arith-to-llvm -convert-memref-to-llvm -convert-std-to-llvm -reconcile-unrealized-casts \ +// RUN: mlir-opt %s -convert-scf-to-std -memref-expand -convert-arith-to-llvm -convert-memref-to-llvm -convert-std-to-llvm -reconcile-unrealized-casts \ // RUN: | mlir-cpu-runner -e main -entry-point-result=void \ // RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext,%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \ // RUN: | FileCheck %s diff --git a/mlir/tools/mlir-opt/CMakeLists.txt b/mlir/tools/mlir-opt/CMakeLists.txt --- a/mlir/tools/mlir-opt/CMakeLists.txt +++ b/mlir/tools/mlir-opt/CMakeLists.txt @@ -18,6 +18,7 @@ MLIRGPUTestPasses MLIRLinalgTestPasses MLIRMathTestPasses + MLIRMemRefTestPasses MLIRSCFTestPasses MLIRShapeTestPasses MLIRSPIRVTestPasses