diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -292,4 +292,14 @@ let constructor = "mlir::createConvertVectorToLLVMPass()"; } +//===----------------------------------------------------------------------===// +// VectorToROCDL +//===----------------------------------------------------------------------===// + +def ConvertVectorToROCDL : Pass<"convert-vector-to-rocdl", "ModuleOp"> { + let summary = "Lower the operations from the vector dialect into the ROCDL " + "dialect"; + let constructor = "mlir::createConvertVectorToROCDLPass()"; +} + #endif // MLIR_CONVERSION_PASSES diff --git a/mlir/include/mlir/Conversion/VectorToROCDL/VectorToROCDL.h b/mlir/include/mlir/Conversion/VectorToROCDL/VectorToROCDL.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Conversion/VectorToROCDL/VectorToROCDL.h @@ -0,0 +1,28 @@ +//===- VectorToROCDL.h - Convert Vector to ROCDL dialect ---*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +#ifndef MLIR_CONVERSION_VECTORTOROCDL_VECTORTOROCDL_H_ +#define MLIR_CONVERSION_VECTORTOROCDL_VECTORTOROCDL_H_ + +#include "mlir/Transforms/DialectConversion.h" + +namespace mlir { +class LLVMTypeConverter; +class OwningRewritePatternList; +class ModuleOp; +template +class OperationPass; + +/// Collect a set of patterns to convert from the Vector dialect to ROCDL. +void populateVectorToROCDLConversionPatterns( + LLVMTypeConverter &converter, OwningRewritePatternList &patterns); + +/// Create a pass to convert vector operations to the ROCDL dialect. 
+std::unique_ptr> createConvertVectorToROCDLPass(); + +} // namespace mlir +#endif // MLIR_CONVERSION_VECTORTOROCDL_VECTORTOROCDL_H_ diff --git a/mlir/include/mlir/InitAllPasses.h b/mlir/include/mlir/InitAllPasses.h --- a/mlir/include/mlir/InitAllPasses.h +++ b/mlir/include/mlir/InitAllPasses.h @@ -29,6 +29,7 @@ #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" #include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.h" #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" +#include "mlir/Conversion/VectorToROCDL/VectorToROCDL.h" #include "mlir/Conversion/VectorToSCF/VectorToSCF.h" #include "mlir/Dialect/Affine/Passes.h" #include "mlir/Dialect/GPU/Passes.h" diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt --- a/mlir/lib/Conversion/CMakeLists.txt +++ b/mlir/lib/Conversion/CMakeLists.txt @@ -13,5 +13,6 @@ add_subdirectory(ShapeToStandard) add_subdirectory(StandardToLLVM) add_subdirectory(StandardToSPIRV) add_subdirectory(VectorToLLVM) +add_subdirectory(VectorToROCDL) add_subdirectory(VectorToSCF) diff --git a/mlir/lib/Conversion/GPUToROCDL/CMakeLists.txt b/mlir/lib/Conversion/GPUToROCDL/CMakeLists.txt --- a/mlir/lib/Conversion/GPUToROCDL/CMakeLists.txt +++ b/mlir/lib/Conversion/GPUToROCDL/CMakeLists.txt @@ -15,4 +15,5 @@ MLIRROCDLIR MLIRPass MLIRStandardToLLVM + MLIRVectorToROCDL ) diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp --- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp +++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp @@ -15,6 +15,7 @@ #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" +#include "mlir/Conversion/VectorToROCDL/VectorToROCDL.h" #include "mlir/Dialect/GPU/GPUDialect.h" #include "mlir/Dialect/GPU/Passes.h" #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" @@ -55,6 +56,7 @@ patterns.clear(); 
populateVectorToLLVMConversionPatterns(converter, patterns); + populateVectorToROCDLConversionPatterns(converter, patterns); populateStdToLLVMConversionPatterns(converter, patterns); populateGpuToROCDLConversionPatterns(converter, patterns); LLVMConversionTarget target(getContext()); diff --git a/mlir/lib/Conversion/VectorToROCDL/CMakeLists.txt b/mlir/lib/Conversion/VectorToROCDL/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/lib/Conversion/VectorToROCDL/CMakeLists.txt @@ -0,0 +1,19 @@ +add_mlir_conversion_library(MLIRVectorToROCDL + VectorToROCDL.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/VectorToROCDL + + DEPENDS + MLIRConversionPassIncGen + intrinsics_gen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIRROCDLIR + MLIRStandardToLLVM + MLIRVector + MLIRTransforms + ) diff --git a/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp b/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp @@ -0,0 +1,201 @@ +//===- VectorToROCDL.cpp - Vector to ROCDL lowering passes ------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a pass to generate ROCDLIR operations for higher-level +// Vector operations. 
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/VectorToROCDL/VectorToROCDL.h" + +#include "../PassDetail.h" +#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" +#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" +#include "mlir/Dialect/GPU/GPUDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/ROCDLDialect.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/Dialect/Vector/VectorOps.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" + +using namespace mlir; +using namespace mlir::vector; + +namespace { + +static TransferReadOpOperandAdaptor +getTransferOpAdapter(TransferReadOp xferOp, ArrayRef operands) { + return TransferReadOpOperandAdaptor(operands); +} + +static TransferWriteOpOperandAdaptor +getTransferOpAdapter(TransferWriteOp xferOp, ArrayRef operands) { + return TransferWriteOpOperandAdaptor(operands); +} + +bool isMinorIdentity(AffineMap map, unsigned rank) { + if (map.getNumResults() < rank) + return false; + unsigned startDim = map.getNumDims() - rank; + for (unsigned i = 0; i < rank; ++i) + if (map.getResult(i) != getAffineDimExpr(startDim + i, map.getContext())) + return false; + return true; +} + +LogicalResult replaceTransferOpWithMubuf( + ConversionPatternRewriter &rewriter, ArrayRef operands, + LLVMTypeConverter &typeConverter, Location loc, TransferReadOp xferOp, + LLVM::LLVMType &vecTy, Value &dwordConfig, Value &int32zero, + Value &offsetSizeInBytes, Value &int1False) { + rewriter.replaceOpWithNewOp(xferOp, vecTy, dwordConfig, + int32zero, offsetSizeInBytes, + int1False, int1False); + return success(); +} + +LogicalResult replaceTransferOpWithMubuf( + ConversionPatternRewriter &rewriter, ArrayRef operands, + LLVMTypeConverter &typeConverter, Location loc, TransferWriteOp xferOp, + LLVM::LLVMType &vecTy, Value &dwordConfig, Value &int32zero, + Value &offsetSizeInBytes, 
Value &int1False) { + auto adaptor = TransferWriteOpOperandAdaptor(operands); + rewriter.replaceOpWithNewOp( + xferOp, adaptor.vector(), dwordConfig, int32zero, offsetSizeInBytes, + int1False, int1False); + + return success(); +} + +template +class VectorTransferConversion : public ConvertToLLVMPattern { +public: + explicit VectorTransferConversion(MLIRContext *context, + LLVMTypeConverter &typeConv) + : ConvertToLLVMPattern(ConcreteOp::getOperationName(), context, + typeConv) {} + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + auto xferOp = cast(op); + auto adaptor = getTransferOpAdapter(xferOp, operands); + + if (xferOp.getVectorType().getRank() > 1 || + llvm::size(xferOp.indices()) == 0) + return failure(); + + if (!isMinorIdentity(xferOp.permutation_map(), + xferOp.getVectorType().getRank())) + return failure(); + + // Dim 0 is not masked: leave it to the vector->llvm conversion pass. + if (!xferOp.isMaskedDim(0)) + return failure(); + + auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); }; + LLVM::LLVMType vecTy = + toLLVMTy(xferOp.getVectorType()).template cast(); + unsigned vecWidth = vecTy.getVectorNumElements(); + Location loc = op->getLoc(); + + if (vecWidth != 1 && vecWidth != 2 && vecWidth != 4) + return failure(); + + // Obtain dataPtr and elementType from the memref. + MemRefType memRefType = xferOp.getMemRefType(); + auto elementType = memRefType.getElementType(); + auto convertedPtrType = typeConverter.convertType(elementType) + .template cast() + .getPointerTo(0); + Value dataPtr = getDataPtr(loc, memRefType, adaptor.memref(), + adaptor.indices(), rewriter, getModule()); + + if (memRefType.getMemorySpace() != 0) + dataPtr = rewriter.create(loc, convertedPtrType, + dataPtr); + + // Create a <4 x i32> dwordConfig with: + // Word 1 and 2: address of dataPtr (0 here, inserted below) + // Word 3: -1 + // Word 4: 0x27000 + SmallVector indices{0, 0, -1, 0x27000}; + Type i32Ty = rewriter.getIntegerType(32); 
+ VectorType i32Vecx4 = VectorType::get(4, i32Ty); + Value constConfig = rewriter.create( + loc, toLLVMTy(i32Vecx4), + DenseElementsAttr::get(i32Vecx4, ArrayRef(indices))); + + // Treat the first two elements of <4 x i32> as one i64 and store the + // dataPtr address in it. + Type i64Ty = rewriter.getIntegerType(64); + Value i64x2Ty = rewriter.create( + loc, + LLVM::LLVMType::getVectorTy( + toLLVMTy(i64Ty).template cast(), 2), + constConfig); + Value dataPtrAsI64 = rewriter.create( + loc, toLLVMTy(i64Ty).template cast(), dataPtr); + Value zero = createIndexConstant(rewriter, loc, 0); + Value dwordConfig = rewriter.create( + loc, + LLVM::LLVMType::getVectorTy( + toLLVMTy(i64Ty).template cast(), 2), + i64x2Ty, dataPtrAsI64, zero); + dwordConfig = + rewriter.create(loc, toLLVMTy(i32Vecx4), dwordConfig); + + // Rewrite op as a buffer read or write. + Value int1False = rewriter.create( + loc, toLLVMTy(rewriter.getIntegerType(1)), + rewriter.getIntegerAttr(rewriter.getIntegerType(1), 0)); + Value int32Zero = rewriter.create( + loc, toLLVMTy(i32Ty), + rewriter.getIntegerAttr(rewriter.getIntegerType(32), 0)); + return replaceTransferOpWithMubuf(rewriter, operands, typeConverter, loc, + xferOp, vecTy, dwordConfig, int32Zero, + int32Zero, int1False); + } +}; +} // end anonymous namespace + +void mlir::populateVectorToROCDLConversionPatterns( + LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { + MLIRContext *ctx = converter.getDialect()->getContext(); + patterns.insert, + VectorTransferConversion>(ctx, converter); +} + +namespace { +struct LowerVectorToROCDLPass + : public ConvertVectorToROCDLBase { + void runOnOperation() override; +}; +} // namespace + +void LowerVectorToROCDLPass::runOnOperation() { + LLVMTypeConverter converter(&getContext()); + OwningRewritePatternList patterns; + + populateVectorToROCDLConversionPatterns(converter, patterns); + populateStdToLLVMConversionPatterns(converter, patterns); + + LLVMConversionTarget target(getContext()); + 
target.addLegalDialect(); + + if (failed(applyPartialConversion(getOperation(), target, patterns, + &converter))) { + signalPassFailure(); + } +} + +std::unique_ptr> +mlir::createConvertVectorToROCDLPass() { + return std::make_unique(); +} diff --git a/mlir/test/Conversion/VectorToROCDL/vector-to-rocdl.mlir b/mlir/test/Conversion/VectorToROCDL/vector-to-rocdl.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Conversion/VectorToROCDL/vector-to-rocdl.mlir @@ -0,0 +1,38 @@ +// RUN: mlir-opt %s -convert-vector-to-rocdl | FileCheck %s + +gpu.module @test_module{ +func @transfer_readx2(%A : memref, %base: index) -> vector<2xf32> { + %f0 = constant 0.0: f32 + %f = vector.transfer_read %A[%base], %f0 + {permutation_map = affine_map<(d0) -> (d0)>} : + memref, vector<2xf32> + return %f: vector<2xf32> +} +// CHECK-LABEL: @transfer_readx2 +// CHECK: rocdl.buffer.load {{.*}} !llvm<"<2 x float>"> + +func @transfer_readx4(%A : memref, %base: index) -> vector<4xf32> { + %f0 = constant 0.0: f32 + %f = vector.transfer_read %A[%base], %f0 + {permutation_map = affine_map<(d0) -> (d0)>} : + memref, vector<4xf32> + return %f: vector<4xf32> +} +// CHECK-LABEL: @transfer_readx4 +// CHECK: rocdl.buffer.load {{.*}} !llvm<"<4 x float>"> + +func @transfer_read_dwordConfig(%A : memref, %base: index) -> vector<4xf32> { + %f0 = constant 0.0: f32 + %f = vector.transfer_read %A[%base], %f0 + {permutation_map = affine_map<(d0) -> (d0)>} : + memref, vector<4xf32> + return %f: vector<4xf32> +} +// CHECK-LABEL: @transfer_read_dwordConfig +// CHECK: %[[gep:.*]] = llvm.getelementptr {{.*}} +// CHECK: [0, 0, -1, 159744] +// CHECK: %[[i64:.*]] = llvm.ptrtoint %[[gep]] +// CHECK: llvm.insertelement %[[i64]] + +} + diff --git a/mlir/test/mlir-rocm-runner/vector-transferops.mlir b/mlir/test/mlir-rocm-runner/vector-transferops.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/mlir-rocm-runner/vector-transferops.mlir @@ -0,0 +1,51 @@ +// RUN: mlir-rocm-runner %s 
--shared-libs=%rocm_wrapper_library_dir/librocm-runtime-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext --entry-point-result=void | FileCheck %s + +func @vectransfer(%arg0 : memref, %arg1 : memref) { + %cst = constant 1 : index + gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %cst, %grid_y = %cst, %grid_z = %cst) + threads(%tx, %ty, %tz) in (%block_x = %cst, %block_y = %cst, %block_z = %cst) { + %f0 = constant 0.0: f32 + %base = constant 0 : index + %f = vector.transfer_read %arg0[%base], %f0 + {permutation_map = affine_map<(d0) -> (d0)>} : + memref, vector<2xf32> + + %c = addf %f, %f : vector<2xf32> + + %base1 = constant 1 : index + vector.transfer_write %c, %arg1[%base1] + {permutation_map = affine_map<(d0) -> (d0)>} : + vector<2xf32>, memref + + gpu.terminator + } + return +} + +// CHECK: [1.23, 2.46, 2.46, 1.23] +func @main() { + %cf1 = constant 1.0 : f32 + + %arg0 = alloc() : memref<4xf32> + %arg1 = alloc() : memref<4xf32> + + %22 = memref_cast %arg0 : memref<4xf32> to memref + %23 = memref_cast %arg1 : memref<4xf32> to memref + + %cast0 = memref_cast %22 : memref to memref<*xf32> + %cast1 = memref_cast %23 : memref to memref<*xf32> + + call @mgpuMemHostRegisterFloat(%cast0) : (memref<*xf32>) -> () + call @mgpuMemHostRegisterFloat(%cast1) : (memref<*xf32>) -> () + + %24 = call @mgpuMemGetDeviceMemRef1dFloat(%22) : (memref) -> (memref) + %26 = call @mgpuMemGetDeviceMemRef1dFloat(%23) : (memref) -> (memref) + + call @vectransfer(%24, %26) : (memref, memref) -> () + call @print_memref_f32(%cast1) : (memref<*xf32>) -> () + return +} + +func @mgpuMemHostRegisterFloat(%ptr : memref<*xf32>) +func @mgpuMemGetDeviceMemRef1dFloat(%ptr : memref) -> (memref) +func @print_memref_f32(%ptr : memref<*xf32>)