diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -272,6 +272,20 @@
 }
 
 //===----------------------------------------------------------------------===//
+// VectorToSCF
+//===----------------------------------------------------------------------===//
+
+def ConvertVectorToSCF : FunctionPass<"convert-vector-to-scf"> {
+  let summary = "Lower the operations from the vector dialect into the SCF "
+                "dialect";
+  let constructor = "mlir::createConvertVectorToSCFPass()";
+  let options = [
+    Option<"fullUnroll", "full-unroll", "bool", /*default=*/"false",
+           "Perform full unrolling when converting vector transfers to SCF">,
+  ];
+}
+
+//===----------------------------------------------------------------------===//
 // VectorToLLVM
 //===----------------------------------------------------------------------===//
 
diff --git a/mlir/include/mlir/Conversion/VectorToSCF/VectorToSCF.h b/mlir/include/mlir/Conversion/VectorToSCF/VectorToSCF.h
--- a/mlir/include/mlir/Conversion/VectorToSCF/VectorToSCF.h
+++ b/mlir/include/mlir/Conversion/VectorToSCF/VectorToSCF.h
@@ -14,6 +14,7 @@
 namespace mlir {
 class MLIRContext;
 class OwningRewritePatternList;
+class Pass;
 
 /// Control whether unrolling is used when lowering vector transfer ops to SCF.
 ///
@@ -164,6 +165,10 @@
     OwningRewritePatternList &patterns, MLIRContext *context,
     const VectorTransferToSCFOptions &options = VectorTransferToSCFOptions());
 
+/// Create a pass to convert a subset of vector ops to SCF.
+std::unique_ptr<Pass> createConvertVectorToSCFPass(
+    const VectorTransferToSCFOptions &options = VectorTransferToSCFOptions());
+
 } // namespace mlir
 
 #endif // MLIR_CONVERSION_VECTORTOSCF_VECTORTOSCF_H_
diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.h b/mlir/include/mlir/Dialect/Vector/VectorOps.h
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.h
@@ -56,6 +56,12 @@
 /// Structure to control the behavior of vector transform patterns.
 struct VectorTransformsOptions {
   VectorContractLowering vectorContractLowering = VectorContractLowering::FMA;
+  /// Sets `vectorContractLowering` and returns `*this` to allow chaining.
+  VectorTransformsOptions &
+  setVectorTransformsOptions(VectorContractLowering opt) {
+    vectorContractLowering = opt;
+    return *this;
+  }
 };
 
 /// Collect a set of transformation patterns that are related to contracting
diff --git a/mlir/include/mlir/InitAllPasses.h b/mlir/include/mlir/InitAllPasses.h
--- a/mlir/include/mlir/InitAllPasses.h
+++ b/mlir/include/mlir/InitAllPasses.h
@@ -28,6 +28,7 @@
 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
 #include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.h"
 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
+#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
 #include "mlir/Dialect/Affine/Passes.h"
 #include "mlir/Dialect/GPU/Passes.h"
 #include "mlir/Dialect/LLVMIR/Transforms/LegalizeForExport.h"
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -124,6 +124,91 @@
   return res;
 }
 
+/// Computes into `align` the preferred alignment of the LLVM type that the
+/// memref element type of `xferOp` converts to; fails if it does not convert.
+template <typename TransferOp>
+LogicalResult getVectorTransferAlignment(LLVMTypeConverter &typeConverter,
+                                         TransferOp xferOp, unsigned &align) {
+  Type elementTy =
+      typeConverter.convertType(xferOp.getMemRefType().getElementType());
+  if (!elementTy)
+    return failure();
+
+  auto dataLayout = typeConverter.getDialect()->getLLVMModule().getDataLayout();
+  align = dataLayout.getPrefTypeAlignment(
+      elementTy.cast<LLVM::LLVMType>().getUnderlyingType());
+  return success();
+}
+
+static LogicalResult
+replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
+                                 LLVMTypeConverter &typeConverter, Location loc,
+                                 TransferReadOp xferOp,
+                                 ArrayRef<Value> operands, Value dataPtr) {
+  rewriter.replaceOpWithNewOp<LLVM::LoadOp>(xferOp, dataPtr);
+  return success();
+}
+
+static LogicalResult
+replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
+                            LLVMTypeConverter &typeConverter, Location loc,
+                            TransferReadOp xferOp, ArrayRef<Value> operands,
+                            Value dataPtr, Value mask) {
+  // Convert the vector type first: the fill cast below needs it, and bailing
+  // out before creating any op keeps a failed match free of side effects.
+  VectorType fillType = xferOp.getVectorType();
+  Type vecTy = typeConverter.convertType(fillType);
+  if (!vecTy)
+    return failure();
+
+  unsigned align;
+  if (failed(getVectorTransferAlignment(typeConverter, xferOp, align)))
+    return failure();
+
+  Value fill = rewriter.create<SplatOp>(loc, fillType, xferOp.padding());
+  fill = rewriter.create<LLVM::DialectCastOp>(loc, vecTy, fill);
+  rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
+      xferOp, vecTy, dataPtr, mask, ValueRange{fill},
+      rewriter.getI32IntegerAttr(align));
+  return success();
+}
+
+static LogicalResult
+replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
+                                 LLVMTypeConverter &typeConverter, Location loc,
+                                 TransferWriteOp xferOp,
+                                 ArrayRef<Value> operands, Value dataPtr) {
+  auto adaptor = TransferWriteOpOperandAdaptor(operands);
+  rewriter.replaceOpWithNewOp<LLVM::StoreOp>(xferOp, adaptor.vector(), dataPtr);
+  return success();
+}
+
+static LogicalResult
+replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
+                            LLVMTypeConverter &typeConverter, Location loc,
+                            TransferWriteOp xferOp, ArrayRef<Value> operands,
+                            Value dataPtr, Value mask) {
+  unsigned align;
+  if (failed(getVectorTransferAlignment(typeConverter, xferOp, align)))
+    return failure();
+
+  auto adaptor = TransferWriteOpOperandAdaptor(operands);
+  rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
+      xferOp, adaptor.vector(), dataPtr, mask,
+      rewriter.getI32IntegerAttr(align));
+  return success();
+}
+
+static TransferReadOpOperandAdaptor
+getTransferOpAdapter(TransferReadOp xferOp, ArrayRef<Value> operands) {
+  return TransferReadOpOperandAdaptor(operands);
+}
+
+static TransferWriteOpOperandAdaptor
+getTransferOpAdapter(TransferWriteOp xferOp, ArrayRef<Value> operands) {
+  return TransferWriteOpOperandAdaptor(operands);
+}
+
 namespace {
 
 /// Conversion pattern for a vector.matrix_multiply.
@@ -767,108 +852,6 @@
   }
 };
 
-LogicalResult getLLVMTypeAndAlignment(LLVMTypeConverter &typeConverter,
-                                      Type type, LLVM::LLVMType &llvmType,
-                                      unsigned &align) {
-  auto convertedType = typeConverter.convertType(type);
-  if (!convertedType)
-    return failure();
-
-  llvmType = convertedType.template cast<LLVM::LLVMType>();
-  auto dataLayout = typeConverter.getDialect()->getLLVMModule().getDataLayout();
-  align = dataLayout.getPrefTypeAlignment(llvmType.getUnderlyingType());
-  return success();
-}
-
-LogicalResult
-replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
-                                 LLVMTypeConverter &typeConverter, Location loc,
-                                 TransferReadOp xferOp,
-                                 ArrayRef<Value> operands, Value dataPtr) {
-  LLVM::LLVMType vecTy;
-  unsigned align;
-  if (failed(getLLVMTypeAndAlignment(typeConverter, xferOp.getVectorType(),
-                                     vecTy, align)))
-    return failure();
-  rewriter.replaceOpWithNewOp<LLVM::LoadOp>(xferOp, dataPtr);
-  return success();
-}
-
-LogicalResult replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
-                                          LLVMTypeConverter &typeConverter,
-                                          Location loc, TransferReadOp xferOp,
-                                          ArrayRef<Value> operands,
-                                          Value dataPtr, Value mask) {
-  auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); };
-  VectorType fillType = xferOp.getVectorType();
-  Value fill = rewriter.create<SplatOp>(loc, fillType, xferOp.padding());
-  fill = rewriter.create<LLVM::DialectCastOp>(loc, toLLVMTy(fillType), fill);
-
-  LLVM::LLVMType vecTy;
-  unsigned align;
-  if (failed(getLLVMTypeAndAlignment(typeConverter, xferOp.getVectorType(),
-                                     vecTy, align)))
-    return failure();
-
-  rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
-      xferOp, vecTy, dataPtr, mask, ValueRange{fill},
-      rewriter.getI32IntegerAttr(align));
-  return success();
-}
-
-LogicalResult
-replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
-                                 LLVMTypeConverter &typeConverter, Location loc,
-                                 TransferWriteOp xferOp,
-                                 ArrayRef<Value> operands, Value dataPtr) {
-  auto adaptor = TransferWriteOpOperandAdaptor(operands);
-  LLVM::LLVMType vecTy;
-  unsigned align;
-  if (failed(getLLVMTypeAndAlignment(typeConverter, xferOp.getVectorType(),
-                                     vecTy, align)))
-    return failure();
-  rewriter.replaceOpWithNewOp<LLVM::StoreOp>(xferOp, adaptor.vector(), dataPtr);
-  return success();
-}
-
-LogicalResult replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
-                                          LLVMTypeConverter &typeConverter,
-                                          Location loc, TransferWriteOp xferOp,
-                                          ArrayRef<Value> operands,
-                                          Value dataPtr, Value mask) {
-  auto adaptor = TransferWriteOpOperandAdaptor(operands);
-  LLVM::LLVMType vecTy;
-  unsigned align;
-  if (failed(getLLVMTypeAndAlignment(typeConverter, xferOp.getVectorType(),
-                                     vecTy, align)))
-    return failure();
-
-  rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
-      xferOp, adaptor.vector(), dataPtr, mask,
-      rewriter.getI32IntegerAttr(align));
-  return success();
-}
-
-static TransferReadOpOperandAdaptor
-getTransferOpAdapter(TransferReadOp xferOp, ArrayRef<Value> operands) {
-  return TransferReadOpOperandAdaptor(operands);
-}
-
-static TransferWriteOpOperandAdaptor
-getTransferOpAdapter(TransferWriteOp xferOp, ArrayRef<Value> operands) {
-  return TransferWriteOpOperandAdaptor(operands);
-}
-
-bool isMinorIdentity(AffineMap map, unsigned rank) {
-  if (map.getNumResults() < rank)
-    return false;
-  unsigned startDim = map.getNumDims() - rank;
-  for (unsigned i = 0; i < rank; ++i)
-    if (map.getResult(i) != getAffineDimExpr(startDim + i, map.getContext()))
-      return false;
-  return true;
-}
-
 /// Conversion pattern that converts a 1-D vector transfer read/write op in a
 /// sequence of:
 ///   1. Bitcast or addrspacecast to vector form.
@@ -892,8 +875,10 @@
     if (xferOp.getVectorType().getRank() > 1 ||
         llvm::size(xferOp.indices()) == 0)
       return failure();
-    if (!isMinorIdentity(xferOp.permutation_map(),
-                         xferOp.getVectorType().getRank()))
+    if (xferOp.permutation_map() !=
+        AffineMap::getMinorIdentityMap(xferOp.permutation_map().getNumInputs(),
+                                       xferOp.getVectorType().getRank(),
+                                       op->getContext()))
       return failure();
 
     auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); };
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -13,6 +13,8 @@
 #include <type_traits>
 
 #include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
+
+#include "../PassDetail.h"
 #include "mlir/Dialect/Affine/EDSC/Intrinsics.h"
 #include "mlir/Dialect/SCF/EDSC/Builders.h"
 #include "mlir/Dialect/SCF/EDSC/Intrinsics.h"
@@ -29,6 +31,8 @@
 #include "mlir/IR/OperationSupport.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/Types.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/Passes.h"
 
 using namespace mlir;
 using namespace mlir::edsc;
@@ -631,3 +635,29 @@
 
 } // namespace mlir
 
+namespace {
+
+/// Applies populateVectorToSCFConversionPatterns greedily to each function.
+struct ConvertVectorToSCFPass
+    : public ConvertVectorToSCFBase<ConvertVectorToSCFPass> {
+  ConvertVectorToSCFPass() = default;
+  ConvertVectorToSCFPass(const ConvertVectorToSCFPass &pass) {}
+  explicit ConvertVectorToSCFPass(const VectorTransferToSCFOptions &options) {
+    this->fullUnroll = options.unroll;
+  }
+
+  void runOnFunction() override {
+    OwningRewritePatternList patterns;
+    auto *context = getFunction().getContext();
+    populateVectorToSCFConversionPatterns(
+        patterns, context, VectorTransferToSCFOptions().setUnroll(fullUnroll));
+    applyPatternsAndFoldGreedily(getFunction(), patterns);
+  }
+};
+
+} // namespace
+
+std::unique_ptr<Pass>
+mlir::createConvertVectorToSCFPass(const VectorTransferToSCFOptions &options) {
+  return std::make_unique<ConvertVectorToSCFPass>(options);
+}
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -818,7 +818,7 @@
 //       CHECK: %[[PASS_THROUGH:.*]] = llvm.mlir.constant(dense<7.000000e+00> :
 //  CHECK-SAME: vector<17xf32>) : !llvm<"<17 x float>">
 //       CHECK: %[[loaded:.*]] = llvm.intr.masked.load %[[vecPtr]], %[[mask]],
-//  CHECK-SAME: %[[PASS_THROUGH]] {alignment = 128 : i32} :
+//  CHECK-SAME: %[[PASS_THROUGH]] {alignment = 4 : i32} :
 //  CHECK-SAME: (!llvm<"<17 x float>*">, !llvm<"<17 x i1>">, !llvm<"<17 x float>">) -> !llvm<"<17 x float>">
 
 //
@@ -850,7 +850,7 @@
 //
 // 5. Rewrite as a masked write.
// CHECK: llvm.intr.masked.store %[[loaded]], %[[vecPtr_b]], %[[mask_b]] -// CHECK-SAME: {alignment = 128 : i32} : +// CHECK-SAME: {alignment = 4 : i32} : // CHECK-SAME: !llvm<"<17 x float>">, !llvm<"<17 x i1>"> into !llvm<"<17 x float>*"> func @transfer_read_2d_to_1d(%A : memref, %base0: index, %base1: index) -> vector<17xf32> { diff --git a/mlir/test/Conversion/VectorToSCF/vector-to-loops.mlir b/mlir/test/Conversion/VectorToSCF/vector-to-loops.mlir --- a/mlir/test/Conversion/VectorToSCF/vector-to-loops.mlir +++ b/mlir/test/Conversion/VectorToSCF/vector-to-loops.mlir @@ -1,5 +1,5 @@ -// RUN: mlir-opt %s -test-convert-vector-to-scf -split-input-file | FileCheck %s -// RUN: mlir-opt %s -test-convert-vector-to-scf=full-unroll=true -split-input-file | FileCheck %s --check-prefix=FULL-UNROLL +// RUN: mlir-opt %s -convert-vector-to-scf -split-input-file | FileCheck %s +// RUN: mlir-opt %s -convert-vector-to-scf=full-unroll=true -split-input-file | FileCheck %s --check-prefix=FULL-UNROLL // CHECK-LABEL: func @materialize_read_1d() { func @materialize_read_1d() { diff --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt --- a/mlir/test/lib/Transforms/CMakeLists.txt +++ b/mlir/test/lib/Transforms/CMakeLists.txt @@ -20,7 +20,6 @@ TestMemRefBoundCheck.cpp TestMemRefDependenceCheck.cpp TestMemRefStrideCalculation.cpp - TestVectorToSCFConversion.cpp TestVectorTransforms.cpp EXCLUDE_FROM_LIBMLIR diff --git a/mlir/test/lib/Transforms/TestVectorToSCFConversion.cpp b/mlir/test/lib/Transforms/TestVectorToSCFConversion.cpp deleted file mode 100644 --- a/mlir/test/lib/Transforms/TestVectorToSCFConversion.cpp +++ /dev/null @@ -1,48 +0,0 @@ -//===- TestVectorToSCFConversion.cpp - Test VectorTransfers lowering ------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. 
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include - -#include "mlir/Conversion/VectorToSCF/VectorToSCF.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Transforms/Passes.h" - -using namespace mlir; - -namespace { - -struct TestVectorToSCFPass - : public PassWrapper { - TestVectorToSCFPass() = default; - TestVectorToSCFPass(const TestVectorToSCFPass &pass) {} - - Option fullUnroll{ - *this, "full-unroll", - llvm::cl::desc( - "Perform full unrolling when converting vector transfers to SCF"), - llvm::cl::init(false)}; - - void runOnFunction() override { - OwningRewritePatternList patterns; - auto *context = &getContext(); - populateVectorToSCFConversionPatterns( - patterns, context, VectorTransferToSCFOptions().setUnroll(fullUnroll)); - applyPatternsAndFoldGreedily(getFunction(), patterns); - } -}; - -} // end anonymous namespace - -namespace mlir { -void registerTestVectorToSCFPass() { - PassRegistration pass( - "test-convert-vector-to-scf", - "Converts vector transfer ops to loops over scalars and vector casts"); -} -} // namespace mlir diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -62,7 +62,6 @@ void registerTestParallelismDetection(); void registerTestGpuParallelLoopMappingPass(); void registerTestVectorConversions(); -void registerTestVectorToSCFPass(); void registerVectorizerTestPass(); } // namespace mlir @@ -133,7 +132,6 @@ registerTestParallelismDetection(); registerTestGpuParallelLoopMappingPass(); registerTestVectorConversions(); - registerTestVectorToSCFPass(); registerVectorizerTestPass(); } #endif