diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -410,17 +410,6 @@ // StandardToSPIRV //===----------------------------------------------------------------------===// -def LegalizeStandardForSPIRV : Pass<"legalize-std-for-spirv"> { - let summary = "Legalize standard ops for SPIR-V lowering"; - let description = [{ - The pass contains certain intra standard op conversions that are meant for - lowering to SPIR-V ops, e.g., folding subviews loads/stores to the original - loads/stores from/to the original memref. - }]; - let constructor = "mlir::createLegalizeStdOpsForSPIRVLoweringPass()"; - let dependentDialects = ["spirv::SPIRVDialect"]; -} - def ConvertStandardToSPIRV : Pass<"convert-std-to-spirv", "ModuleOp"> { let summary = "Convert Standard dialect to SPIR-V dialect"; let constructor = "mlir::createConvertStandardToSPIRVPass()"; diff --git a/mlir/include/mlir/Conversion/StandardToSPIRV/StandardToSPIRV.h b/mlir/include/mlir/Conversion/StandardToSPIRV/StandardToSPIRV.h --- a/mlir/include/mlir/Conversion/StandardToSPIRV/StandardToSPIRV.h +++ b/mlir/include/mlir/Conversion/StandardToSPIRV/StandardToSPIRV.h @@ -40,11 +40,6 @@ int64_t byteCountThreshold, RewritePatternSet &patterns); -/// Appends to a pattern list patterns to legalize ops that are not directly -/// lowered to SPIR-V. -void populateStdLegalizationPatternsForSPIRVLowering( - RewritePatternSet &patterns); - } // namespace mlir #endif // MLIR_CONVERSION_STANDARDTOSPIRV_STANDARDTOSPIRV_H diff --git a/mlir/include/mlir/Conversion/StandardToSPIRV/StandardToSPIRVPass.h b/mlir/include/mlir/Conversion/StandardToSPIRV/StandardToSPIRVPass.h --- a/mlir/include/mlir/Conversion/StandardToSPIRV/StandardToSPIRVPass.h +++ b/mlir/include/mlir/Conversion/StandardToSPIRV/StandardToSPIRVPass.h @@ -20,9 +20,6 @@ /// Creates a pass to convert standard ops to SPIR-V ops.
std::unique_ptr<OperationPass<ModuleOp>> createConvertStandardToSPIRVPass(); -/// Creates a pass to legalize ops that are not directly lowered to SPIR-V. -std::unique_ptr<Pass> createLegalizeStdOpsForSPIRVLoweringPass(); - } // namespace mlir #endif // MLIR_CONVERSION_STANDARDTOSPIRV_STANDARDTOSPIRVPASS_H diff --git a/mlir/include/mlir/Dialect/MemRef/CMakeLists.txt b/mlir/include/mlir/Dialect/MemRef/CMakeLists.txt --- a/mlir/include/mlir/Dialect/MemRef/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/MemRef/CMakeLists.txt @@ -1 +1,2 @@ add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/MemRef/Transforms/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/MemRef/Transforms/CMakeLists.txt @@ -0,0 +1,6 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls -name MemRef) +add_public_tablegen_target(MLIRMemRefPassIncGen) +add_dependencies(mlir-headers MLIRMemRefPassIncGen) + +add_mlir_doc(Passes -gen-pass-doc MemRefPasses ./) diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h @@ -0,0 +1,47 @@ +//===- Passes.h - MemRef Patterns and Passes --------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This header declares patterns and passes on MemRef operations.
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_MEMREF_TRANSFORMS_PASSES_H +#define MLIR_DIALECT_MEMREF_TRANSFORMS_PASSES_H + +#include "mlir/Pass/Pass.h" + +namespace mlir { +namespace memref { + +//===----------------------------------------------------------------------===// +// Patterns +//===----------------------------------------------------------------------===// + +/// Appends patterns for folding memref.subview ops into consumer load/store ops +/// into `patterns`. +void populateFoldSubViewOpPatterns(RewritePatternSet &patterns); + +//===----------------------------------------------------------------------===// +// Passes +//===----------------------------------------------------------------------===// + +/// Creates an operation pass to fold memref.subview ops into consumer +/// load/store ops. +std::unique_ptr<Pass> createFoldSubViewOpsPass(); + +//===----------------------------------------------------------------------===// +// Registration +//===----------------------------------------------------------------------===// + +#define GEN_PASS_REGISTRATION +#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc" + +} // namespace memref +} // namespace mlir + +#endif // MLIR_DIALECT_MEMREF_TRANSFORMS_PASSES_H diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td @@ -0,0 +1,26 @@ +//===-- Passes.td - MemRef transformation definition file --*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_MEMREF_TRANSFORMS_PASSES +#define MLIR_DIALECT_MEMREF_TRANSFORMS_PASSES + +include "mlir/Pass/PassBase.td" + +def FoldSubViewOps : Pass<"fold-memref-subview-ops"> { + let summary = "Fold memref.subview ops into consumer load/store ops"; + let description = [{ + The pass folds loading/storing from/to subview ops to loading/storing + from/to the original memref. + }]; + let constructor = "mlir::memref::createFoldSubViewOpsPass()"; + let dependentDialects = [ + "memref::MemRefDialect", "StandardOpsDialect", "vector::VectorDialect" + ]; +} + +#endif // MLIR_DIALECT_MEMREF_TRANSFORMS_PASSES diff --git a/mlir/include/mlir/InitAllPasses.h b/mlir/include/mlir/InitAllPasses.h --- a/mlir/include/mlir/InitAllPasses.h +++ b/mlir/include/mlir/InitAllPasses.h @@ -20,6 +20,7 @@ #include "mlir/Dialect/GPU/Passes.h" #include "mlir/Dialect/LLVMIR/Transforms/Passes.h" #include "mlir/Dialect/Linalg/Passes.h" +#include "mlir/Dialect/MemRef/Transforms/Passes.h" #include "mlir/Dialect/Quant/Passes.h" #include "mlir/Dialect/SCF/Passes.h" #include "mlir/Dialect/SPIRV/Transforms/Passes.h" @@ -55,6 +56,7 @@ registerGpuSerializeToHsacoPass(); registerLinalgPasses(); LLVM::registerLLVMPasses(); + memref::registerMemRefPasses(); quant::registerQuantPasses(); registerSCFPasses(); registerShapePasses(); diff --git a/mlir/lib/Conversion/StandardToSPIRV/CMakeLists.txt b/mlir/lib/Conversion/StandardToSPIRV/CMakeLists.txt --- a/mlir/lib/Conversion/StandardToSPIRV/CMakeLists.txt +++ b/mlir/lib/Conversion/StandardToSPIRV/CMakeLists.txt @@ -1,5 +1,4 @@ add_mlir_conversion_library(MLIRStandardToSPIRV - LegalizeStandardForSPIRV.cpp StandardToSPIRV.cpp StandardToSPIRVPass.cpp diff --git a/mlir/lib/Dialect/MemRef/CMakeLists.txt b/mlir/lib/Dialect/MemRef/CMakeLists.txt --- a/mlir/lib/Dialect/MemRef/CMakeLists.txt +++ b/mlir/lib/Dialect/MemRef/CMakeLists.txt @@
-1,23 +1,3 @@ -add_mlir_dialect_library(MLIRMemRef - IR/MemRefDialect.cpp - IR/MemRefOps.cpp - Utils/MemRefUtils.cpp - - ADDITIONAL_HEADER_DIRS - ${PROJECT_SOURCE_DIR}/inlude/mlir/Dialect/MemRefDialect - - DEPENDS - MLIRStandardOpsIncGen - MLIRMemRefOpsIncGen - - LINK_COMPONENTS - Core - - LINK_LIBS PUBLIC - MLIRDialect - MLIRInferTypeOpInterface - MLIRIR - MLIRStandard - MLIRTensor - MLIRViewLikeInterface -) +add_subdirectory(IR) +add_subdirectory(Transforms) +add_subdirectory(Utils) diff --git a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt --- a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt @@ -16,6 +16,7 @@ MLIRDialect MLIRInferTypeOpInterface MLIRIR + MLIRMemRefUtils MLIRStandard MLIRTensor MLIRViewLikeInterface diff --git a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt @@ -0,0 +1,16 @@ +add_mlir_dialect_library(MLIRMemRefTransforms + FoldSubViewOps.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/MemRef + + DEPENDS + MLIRMemRefPassIncGen + + LINK_LIBS PUBLIC + MLIRMemRef + MLIRPass + MLIRStandard + MLIRTransforms + MLIRVector +) diff --git a/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldSubViewOps.cpp rename from mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp rename to mlir/lib/Dialect/MemRef/Transforms/FoldSubViewOps.cpp --- a/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/FoldSubViewOps.cpp @@ -1,4 +1,4 @@ -//===- LegalizeStandardForSPIRV.cpp - Legalize ops for SPIR-V lowering ----===// +//===- FoldSubViewOps.cpp - Fold memref.subview ops -----------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information. @@ -6,16 +6,13 @@ // //===----------------------------------------------------------------------===// // -// This transformation pass legalizes operations before the conversion to SPIR-V -// dialect to handle ops that cannot be lowered directly. +// This transformation pass folds loading/storing from/to subview ops into +// loading/storing from/to the original memref. // //===----------------------------------------------------------------------===// -#include "../PassDetail.h" -#include "mlir/Conversion/StandardToSPIRV/StandardToSPIRV.h" -#include "mlir/Conversion/StandardToSPIRV/StandardToSPIRVPass.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" +#include "mlir/Dialect/MemRef/Transforms/Passes.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/Vector/VectorOps.h" #include "mlir/IR/BuiltinTypes.h" @@ -23,6 +20,49 @@ using namespace mlir; +//===----------------------------------------------------------------------===// +// Utility functions +//===----------------------------------------------------------------------===// + +/// Given the 'indices' of a load/store operation where the memref is a result +/// of a subview op, computes into `sourceIndices` the indices w.r.t. the +/// source memref of the subview op. For example +/// +/// %0 = ... : memref<12x42xf32> +/// %1 = subview %0[%arg0, %arg1][][%stride1, %stride2] : memref<12x42xf32> to +/// memref<4x4xf32, offset=?, strides=[?, ?]> +/// %2 = load %1[%i1, %i2] : memref<4x4xf32, offset=?, strides=[?, ?]> +/// +/// could be folded into +/// +/// %2 = load %0[%arg0 + %i1 * %stride1, %arg1 + %i2 * %stride2] : +/// memref<12x42xf32> +static LogicalResult +resolveSourceIndices(Location loc, PatternRewriter &rewriter, + memref::SubViewOp subViewOp, ValueRange indices, + SmallVectorImpl<Value> &sourceIndices) { + // TODO: Static offsets/strides are materialized as constant ops and still
+ // combined through muli/addi; they could be folded directly into the + // source indices instead. + SmallVector<Range, 4> opRanges = subViewOp.getOrCreateRanges(rewriter, loc); + auto opOffsets = llvm::map_range(opRanges, [](Range r) { return r.offset; }); + auto opStrides = llvm::map_range(opRanges, [](Range r) { return r.stride; }); + assert(opRanges.size() == indices.size() && + "expected as many indices as rank of subview op result type"); + + // New indices for the load/store are the current indices * subview_stride + + // subview_offset. + sourceIndices.resize(indices.size()); + for (auto index : llvm::enumerate(indices)) { + auto offset = *(opOffsets.begin() + index.index()); + auto stride = *(opStrides.begin() + index.index()); + auto mul = rewriter.create<MulIOp>(loc, index.value(), stride); + sourceIndices[index.index()] = + rewriter.create<AddIOp>(loc, offset, mul).getResult(); + } + return success(); +} + /// Helpers to access the memref operand for each op. static Value getMemRefOperand(memref::LoadOp op) { return op.memref(); } @@ -34,6 +74,10 @@ return op.source(); } +//===----------------------------------------------------------------------===// +// Patterns +//===----------------------------------------------------------------------===// + namespace { /// Merges subview operation with load/transferRead operation. template <typename OpTy> @@ -101,62 +145,15 @@ } } // namespace -//===----------------------------------------------------------------------===// -// Utility functions for op legalization. -//===----------------------------------------------------------------------===// - -/// Given the 'indices' of an load/store operation where the memref is a result -/// of a subview op, returns the indices w.r.t to the source memref of the -/// subview op. For example -/// -/// %0 = ...
: memref<12x42xf32> -/// %1 = subview %0[%arg0, %arg1][][%stride1, %stride2] : memref<12x42xf32> to -/// memref<4x4xf32, offset=?, strides=[?, ?]> -/// %2 = load %1[%i1, %i2] : memref<4x4xf32, offset=?, strides=[?, ?]> -/// -/// could be folded into -/// -/// %2 = load %0[%arg0 + %i1 * %stride1][%arg1 + %i2 * %stride2] : -/// memref<12x42xf32> -static LogicalResult -resolveSourceIndices(Location loc, PatternRewriter &rewriter, - memref::SubViewOp subViewOp, ValueRange indices, - SmallVectorImpl<Value> &sourceIndices) { - // TODO: Aborting when the offsets are static. There might be a way to fold - // the subview op with load even if the offsets have been canonicalized - // away. - SmallVector<Range, 4> opRanges = subViewOp.getOrCreateRanges(rewriter, loc); - auto opOffsets = llvm::map_range(opRanges, [](Range r) { return r.offset; }); - auto opStrides = llvm::map_range(opRanges, [](Range r) { return r.stride; }); - assert(opRanges.size() == indices.size() && - "expected as many indices as rank of subview op result type"); - - // New indices for the load are the current indices * subview_stride + - // subview_offset. - sourceIndices.resize(indices.size()); - for (auto index : llvm::enumerate(indices)) { - auto offset = *(opOffsets.begin() + index.index()); - auto stride = *(opStrides.begin() + index.index()); - auto mul = rewriter.create<MulIOp>(loc, index.value(), stride); - sourceIndices[index.index()] = - rewriter.create<AddIOp>(loc, offset, mul).getResult(); - } - return success(); -} - -//===----------------------------------------------------------------------===// -// Folding SubViewOp and LoadOp/TransferReadOp.
-//===----------------------------------------------------------------------===// - template <typename OpTy> LogicalResult LoadOpOfSubViewFolder<OpTy>::matchAndRewrite(OpTy loadOp, PatternRewriter &rewriter) const { auto subViewOp = getMemRefOperand(loadOp).template getDefiningOp<memref::SubViewOp>(); - if (!subViewOp) { + if (!subViewOp) return failure(); - } + SmallVector<Value, 4> sourceIndices; if (failed(resolveSourceIndices(loadOp.getLoc(), rewriter, subViewOp, loadOp.indices(), sourceIndices))) @@ -166,19 +163,15 @@ return success(); } -//===----------------------------------------------------------------------===// -// Folding SubViewOp and StoreOp/TransferWriteOp. -//===----------------------------------------------------------------------===// - template <typename OpTy> LogicalResult StoreOpOfSubViewFolder<OpTy>::matchAndRewrite(OpTy storeOp, PatternRewriter &rewriter) const { auto subViewOp = getMemRefOperand(storeOp).template getDefiningOp<memref::SubViewOp>(); - if (!subViewOp) { + if (!subViewOp) return failure(); - } + SmallVector<Value, 4> sourceIndices; if (failed(resolveSourceIndices(storeOp.getLoc(), rewriter, subViewOp, storeOp.indices(), sourceIndices))) @@ -188,12 +181,7 @@ return success(); } -//===----------------------------------------------------------------------===// -// Hook for adding patterns. -//===----------------------------------------------------------------------===// - -void mlir::populateStdLegalizationPatternsForSPIRVLowering( - RewritePatternSet &patterns) { +void memref::populateFoldSubViewOpPatterns(RewritePatternSet &patterns) { patterns.add<LoadOpOfSubViewFolder<memref::LoadOp>, LoadOpOfSubViewFolder<vector::TransferReadOp>, StoreOpOfSubViewFolder<memref::StoreOp>, @@ -202,23 +190,28 @@ } //===----------------------------------------------------------------------===// -// Pass for testing just the legalization patterns.
+// Pass registration //===----------------------------------------------------------------------===// namespace { -struct SPIRVLegalization final - : public LegalizeStandardForSPIRVBase<SPIRVLegalization> { + +#define GEN_PASS_CLASSES +#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc" + +struct FoldSubViewOpsPass final + : public FoldSubViewOpsBase<FoldSubViewOpsPass> { void runOnOperation() override; }; + } // namespace -void SPIRVLegalization::runOnOperation() { +void FoldSubViewOpsPass::runOnOperation() { RewritePatternSet patterns(&getContext()); - populateStdLegalizationPatternsForSPIRVLowering(patterns); + memref::populateFoldSubViewOpPatterns(patterns); (void)applyPatternsAndFoldGreedily(getOperation()->getRegions(), std::move(patterns)); } -std::unique_ptr<Pass> mlir::createLegalizeStdOpsForSPIRVLoweringPass() { - return std::make_unique<SPIRVLegalization>(); +std::unique_ptr<Pass> memref::createFoldSubViewOpsPass() { + return std::make_unique<FoldSubViewOpsPass>(); } diff --git a/mlir/lib/Dialect/MemRef/Utils/CMakeLists.txt b/mlir/lib/Dialect/MemRef/Utils/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/MemRef/Utils/CMakeLists.txt @@ -0,0 +1,10 @@ +add_mlir_dialect_library(MLIRMemRefUtils + MemRefUtils.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/MemRef/Utils + + LINK_LIBS PUBLIC + MLIRIR + MLIRSideEffectInterfaces +) diff --git a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp --- a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp +++ b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp @@ -11,7 +11,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" using namespace mlir; diff --git a/mlir/test/Conversion/StandardToSPIRV/subview-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/subview-to-spirv.mlir deleted file mode 100644 ---
a/mlir/test/Conversion/StandardToSPIRV/subview-to-spirv.mlir +++ /dev/null @@ -1,38 +0,0 @@ -// RUN: mlir-opt -legalize-std-for-spirv %s -o - | FileCheck %s - -module { - -//===----------------------------------------------------------------------===// -// memref.subview -//===----------------------------------------------------------------------===// - -// CHECK-LABEL: @fold_static_stride_subview -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<12x32xf32> -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: index -// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: index -// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]*]]: index -// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]*]]: index -func @fold_static_stride_subview - (%arg0 : memref<12x32xf32>, %arg1 : index, - %arg2 : index, %arg3 : index, %arg4 : index) { - // CHECK-DAG: %[[C2:.*]] = constant 2 - // CHECK-DAG: %[[C3:.*]] = constant 3 - // CHECK: %[[T0:.*]] = muli %[[ARG3]], %[[C3]] - // CHECK: %[[T1:.*]] = addi %[[ARG1]], %[[T0]] - // CHECK: %[[T2:.*]] = muli %[[ARG4]], %[[ARG2]] - // CHECK: %[[T3:.*]] = addi %[[T2]], %[[C2]] - // CHECK: %[[LOADVAL:.*]] = memref.load %[[ARG0]][%[[T1]], %[[T3]]] - // CHECK: %[[STOREVAL:.*]] = math.sqrt %[[LOADVAL]] - // CHECK: %[[T6:.*]] = muli %[[ARG3]], %[[C3]] - // CHECK: %[[T7:.*]] = addi %[[ARG1]], %[[T6]] - // CHECK: %[[T8:.*]] = muli %[[ARG4]], %[[ARG2]] - // CHECK: %[[T9:.*]] = addi %[[T8]], %[[C2]] - // CHECK: memref.store %[[STOREVAL]], %[[ARG0]][%[[T7]], %[[T9]]] - %0 = memref.subview %arg0[%arg1, 2][4, 4][3, %arg2] : memref<12x32xf32> to memref<4x4xf32, offset:?, strides: [96, ?]> - %1 = memref.load %0[%arg3, %arg4] : memref<4x4xf32, offset:?, strides: [96, ?]> - %2 = math.sqrt %1 : f32 - memref.store %2, %0[%arg3, %arg4] : memref<4x4xf32, offset:?, strides: [96, ?]> - return -} - -} // end module diff --git a/mlir/test/Conversion/StandardToSPIRV/legalization.mlir b/mlir/test/Dialect/MemRef/fold-subview-ops.mlir rename from mlir/test/Conversion/StandardToSPIRV/legalization.mlir rename to
mlir/test/Dialect/MemRef/fold-subview-ops.mlir --- a/mlir/test/Conversion/StandardToSPIRV/legalization.mlir +++ b/mlir/test/Dialect/MemRef/fold-subview-ops.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt -legalize-std-for-spirv -verify-diagnostics %s -o - | FileCheck %s +// RUN: mlir-opt -fold-memref-subview-ops -verify-diagnostics %s -o - | FileCheck %s // CHECK-LABEL: @fold_static_stride_subview_with_load // CHECK-SAME: [[ARG0:%.*]]: memref<12x32xf32>, [[ARG1:%.*]]: index, [[ARG2:%.*]]: index, [[ARG3:%.*]]: index, [[ARG4:%.*]]: index diff --git a/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp b/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp --- a/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp +++ b/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp @@ -20,6 +20,7 @@ #include "mlir/Dialect/GPU/Passes.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/MemRef/Transforms/Passes.h" #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" #include "mlir/Dialect/SPIRV/Transforms/Passes.h" @@ -40,7 +41,7 @@ applyPassManagerCLOptions(passManager); passManager.addPass(createGpuKernelOutliningPass()); - passManager.addPass(createLegalizeStdOpsForSPIRVLoweringPass()); + passManager.addPass(memref::createFoldSubViewOpsPass()); passManager.addPass(createConvertGPUToSPIRVPass()); OpPassManager &modulePM = passManager.nest<spirv::ModuleOp>(); modulePM.addPass(spirv::createLowerABIAttributesPass());