diff --git a/mlir/include/mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h b/mlir/include/mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h @@ -0,0 +1,28 @@ +//===- MemRefToSPIRV.h - MemRef to SPIR-V Patterns --------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Provides patterns to convert MemRef dialect to SPIR-V dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_CONVERSION_MEMREFTOSPIRV_MEMREFTOSPIRV_H +#define MLIR_CONVERSION_MEMREFTOSPIRV_MEMREFTOSPIRV_H + +#include "mlir/Transforms/DialectConversion.h" + +namespace mlir { +class SPIRVTypeConverter; + +/// Appends to a pattern list additional patterns for translating MemRef ops +/// to SPIR-V ops. +void populateMemRefToSPIRVPatterns(SPIRVTypeConverter &typeConverter, + RewritePatternSet &patterns); + +} // namespace mlir + +#endif // MLIR_CONVERSION_MEMREFTOSPIRV_MEMREFTOSPIRV_H diff --git a/mlir/include/mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h b/mlir/include/mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h @@ -0,0 +1,25 @@ +//===- MemRefToSPIRVPass.h - MemRef to SPIR-V Passes ------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Provides passes to convert MemRef dialect to SPIR-V dialect. 
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_CONVERSION_MEMREFTOSPIRV_MEMREFTOSPIRVPASS_H +#define MLIR_CONVERSION_MEMREFTOSPIRV_MEMREFTOSPIRVPASS_H + +#include "mlir/Pass/Pass.h" + +namespace mlir { + +/// Creates a pass to convert MemRef ops to SPIR-V ops. +std::unique_ptr> createConvertMemRefToSPIRVPass(); + +} // namespace mlir + +#endif // MLIR_CONVERSION_MEMREFTOSPIRV_MEMREFTOSPIRVPASS_H diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h --- a/mlir/include/mlir/Conversion/Passes.h +++ b/mlir/include/mlir/Conversion/Passes.h @@ -26,6 +26,7 @@ #include "mlir/Conversion/MathToLibm/MathToLibm.h" #include "mlir/Conversion/MathToSPIRV/MathToSPIRVPass.h" #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" +#include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h" #include "mlir/Conversion/OpenACCToLLVM/ConvertOpenACCToLLVM.h" #include "mlir/Conversion/OpenACCToSCF/ConvertOpenACCToSCF.h" #include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h" diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -296,6 +296,16 @@ ]; } +//===----------------------------------------------------------------------===// +// MemRefToSPIRV +//===----------------------------------------------------------------------===// + +def ConvertMemRefToSPIRV : Pass<"convert-memref-to-spirv", "ModuleOp"> { + let summary = "Convert MemRef dialect to SPIR-V dialect"; + let constructor = "mlir::createConvertMemRefToSPIRVPass()"; + let dependentDialects = ["spirv::SPIRVDialect"]; +} + //===----------------------------------------------------------------------===// // OpenACCToSCF //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt --- 
a/mlir/lib/Conversion/CMakeLists.txt +++ b/mlir/lib/Conversion/CMakeLists.txt @@ -16,6 +16,7 @@ add_subdirectory(MathToLLVM) add_subdirectory(MathToSPIRV) add_subdirectory(MemRefToLLVM) +add_subdirectory(MemRefToSPIRV) add_subdirectory(OpenACCToLLVM) add_subdirectory(OpenACCToSCF) add_subdirectory(OpenMPToLLVM) diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp --- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp +++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp @@ -15,6 +15,7 @@ #include "../PassDetail.h" #include "mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h" +#include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h" #include "mlir/Conversion/StandardToSPIRV/StandardToSPIRV.h" #include "mlir/Dialect/GPU/GPUDialect.h" #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" @@ -59,6 +60,10 @@ SPIRVTypeConverter typeConverter(targetAttr); RewritePatternSet patterns(context); populateGPUToSPIRVPatterns(typeConverter, patterns); + + // TODO: Change SPIR-V conversion to be progressive and remove the following + // patterns. 
+ populateMemRefToSPIRVPatterns(typeConverter, patterns); populateStandardToSPIRVPatterns(typeConverter, patterns); if (failed(applyFullConversion(kernelModules, *target, std::move(patterns)))) diff --git a/mlir/lib/Conversion/MemRefToSPIRV/CMakeLists.txt b/mlir/lib/Conversion/MemRefToSPIRV/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/lib/Conversion/MemRefToSPIRV/CMakeLists.txt @@ -0,0 +1,21 @@ +add_mlir_conversion_library(MLIRMemRefToSPIRV + MemRefToSPIRV.cpp + MemRefToSPIRVPass.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SPIRV + ${MLIR_MAIN_INCLUDE_DIR}/mlir/IR + + DEPENDS + MLIRConversionPassIncGen + + LINK_LIBS PUBLIC + MLIRIR + MLIRMemRef + MLIRPass + MLIRSPIRV + MLIRSPIRVConversion + MLIRSupport + MLIRTransformUtils + ) + diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp @@ -0,0 +1,486 @@ +//===- MemRefToSPIRV.cpp - MemRef to SPIR-V Patterns ----------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements patterns to convert MemRef dialect to SPIR-V dialect. 
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" +#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" +#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "memref-to-spirv-pattern" + +using namespace mlir; + +//===----------------------------------------------------------------------===// +// Utility functions +//===----------------------------------------------------------------------===// + +/// Returns the offset of the value in `targetBits` representation. +/// +/// `srcIdx` is an index into a 1-D array with each element having `sourceBits`. +/// It's assumed to be non-negative. +/// +/// When accessing an element in the array treating as having elements of +/// `targetBits`, multiple values are loaded in the same time. The method +/// returns the offset where the `srcIdx` locates in the value. For example, if +/// `sourceBits` equals to 8 and `targetBits` equals to 32, the x-th element is +/// located at (x % 4) * 8. Because there are four elements in one i32, and one +/// element has 8 bits. +static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits, + int targetBits, OpBuilder &builder) { + assert(targetBits % sourceBits == 0); + IntegerType targetType = builder.getIntegerType(targetBits); + IntegerAttr idxAttr = + builder.getIntegerAttr(targetType, targetBits / sourceBits); + auto idx = builder.create(loc, targetType, idxAttr); + IntegerAttr srcBitsAttr = builder.getIntegerAttr(targetType, sourceBits); + auto srcBitsValue = + builder.create(loc, targetType, srcBitsAttr); + auto m = builder.create(loc, srcIdx, idx); + return builder.create(loc, targetType, m, srcBitsValue); +} + +/// Returns an adjusted spirv::AccessChainOp. Based on the +/// extension/capabilities, certain integer bitwidths `sourceBits` might not be +/// supported. 
During conversion if a memref of an unsupported type is used, +/// load/stores to this memref need to be modified to use a supported higher +/// bitwidth `targetBits` and extracting the required bits. For an accessing a +/// 1D array (spv.array or spv.rt_array), the last index is modified to load the +/// bits needed. The extraction of the actual bits needed are handled +/// separately. Note that this only works for a 1-D tensor. +static Value adjustAccessChainForBitwidth(SPIRVTypeConverter &typeConverter, + spirv::AccessChainOp op, + int sourceBits, int targetBits, + OpBuilder &builder) { + assert(targetBits % sourceBits == 0); + const auto loc = op.getLoc(); + IntegerType targetType = builder.getIntegerType(targetBits); + IntegerAttr attr = + builder.getIntegerAttr(targetType, targetBits / sourceBits); + auto idx = builder.create(loc, targetType, attr); + auto lastDim = op->getOperand(op.getNumOperands() - 1); + auto indices = llvm::to_vector<4>(op.indices()); + // There are two elements if this is a 1-D tensor. + assert(indices.size() == 2); + indices.back() = builder.create(loc, lastDim, idx); + Type t = typeConverter.convertType(op.component_ptr().getType()); + return builder.create(loc, t, op.base_ptr(), indices); +} + +/// Returns the shifted `targetBits`-bit value with the given offset. +static Value shiftValue(Location loc, Value value, Value offset, Value mask, + int targetBits, OpBuilder &builder) { + Type targetType = builder.getIntegerType(targetBits); + Value result = builder.create(loc, value, mask); + return builder.create(loc, targetType, result, + offset); +} + +/// Returns true if the allocations of type `t` can be lowered to SPIR-V. +static bool isAllocationSupported(MemRefType t) { + // Currently only support workgroup local memory allocations with static + // shape and int or float or vector of int or float element type. 
+ if (!(t.hasStaticShape() && + SPIRVTypeConverter::getMemorySpaceForStorageClass( + spirv::StorageClass::Workgroup) == t.getMemorySpaceAsInt())) + return false; + Type elementType = t.getElementType(); + if (auto vecType = elementType.dyn_cast()) + elementType = vecType.getElementType(); + return elementType.isIntOrFloat(); +} + +/// Returns the scope to use for atomic operations used for emulating store +/// operations of unsupported integer bitwidths, based on the memref +/// type. Returns None on failure. +static Optional getAtomicOpScope(MemRefType t) { + Optional storageClass = + SPIRVTypeConverter::getStorageClassForMemorySpace( + t.getMemorySpaceAsInt()); + if (!storageClass) + return {}; + switch (*storageClass) { + case spirv::StorageClass::StorageBuffer: + return spirv::Scope::Device; + case spirv::StorageClass::Workgroup: + return spirv::Scope::Workgroup; + default: { + } + } + return {}; +} + +//===----------------------------------------------------------------------===// +// Operation conversion +//===----------------------------------------------------------------------===// + +// Note that DRR cannot be used for the patterns in this file: we may need to +// convert type along the way, which requires ConversionPattern. DRR generates +// normal RewritePattern. + +namespace { + +/// Converts an allocation operation to SPIR-V. Currently only supports lowering +/// to Workgroup memory when the size is constant. Note that this pattern needs +/// to be applied in a pass that runs at least at spv.module scope since it will +/// add global variables into the spv.module. +class AllocOpPattern final : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(memref::AllocOp operation, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override; +}; + +/// Removes a deallocation if it is a supported allocation. 
Currently only +/// removes deallocation if the memory space is workgroup memory. +class DeallocOpPattern final : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(memref::DeallocOp operation, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override; +}; + +/// Converts memref.load to spv.Load. +class IntLoadOpPattern final : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(memref::LoadOp loadOp, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override; +}; + +/// Converts memref.load to spv.Load. +class LoadOpPattern final : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(memref::LoadOp loadOp, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override; +}; + +/// Converts memref.store to spv.Store on integers. +class IntStoreOpPattern final : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(memref::StoreOp storeOp, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override; +}; + +/// Converts memref.store to spv.Store. 
+class StoreOpPattern final : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(memref::StoreOp storeOp, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override; +}; + +} // namespace + +//===----------------------------------------------------------------------===// +// AllocOp +//===----------------------------------------------------------------------===// + +LogicalResult +AllocOpPattern::matchAndRewrite(memref::AllocOp operation, + ArrayRef operands, + ConversionPatternRewriter &rewriter) const { + MemRefType allocType = operation.getType(); + if (!isAllocationSupported(allocType)) + return operation.emitError("unhandled allocation type"); + + // Get the SPIR-V type for the allocation. + Type spirvType = getTypeConverter()->convertType(allocType); + + // Insert spv.GlobalVariable for this allocation. + Operation *parent = + SymbolTable::getNearestSymbolTable(operation->getParentOp()); + if (!parent) + return failure(); + Location loc = operation.getLoc(); + spirv::GlobalVariableOp varOp; + { + OpBuilder::InsertionGuard guard(rewriter); + Block &entryBlock = *parent->getRegion(0).begin(); + rewriter.setInsertionPointToStart(&entryBlock); + auto varOps = entryBlock.getOps(); + std::string varName = + std::string("__workgroup_mem__") + + std::to_string(std::distance(varOps.begin(), varOps.end())); + varOp = rewriter.create(loc, spirvType, varName, + /*initializer=*/nullptr); + } + + // Get pointer to global variable at the current scope. 
+ rewriter.replaceOpWithNewOp(operation, varOp); + return success(); +} + +//===----------------------------------------------------------------------===// +// DeallocOp +//===----------------------------------------------------------------------===// + +LogicalResult +DeallocOpPattern::matchAndRewrite(memref::DeallocOp operation, + ArrayRef operands, + ConversionPatternRewriter &rewriter) const { + MemRefType deallocType = operation.memref().getType().cast(); + if (!isAllocationSupported(deallocType)) + return operation.emitError("unhandled deallocation type"); + rewriter.eraseOp(operation); + return success(); +} + +//===----------------------------------------------------------------------===// +// LoadOp +//===----------------------------------------------------------------------===// + +LogicalResult +IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, + ArrayRef operands, + ConversionPatternRewriter &rewriter) const { + memref::LoadOpAdaptor loadOperands(operands); + auto loc = loadOp.getLoc(); + auto memrefType = loadOp.memref().getType().cast(); + if (!memrefType.getElementType().isSignlessInteger()) + return failure(); + + auto &typeConverter = *getTypeConverter(); + spirv::AccessChainOp accessChainOp = + spirv::getElementPtr(typeConverter, memrefType, loadOperands.memref(), + loadOperands.indices(), loc, rewriter); + + int srcBits = memrefType.getElementType().getIntOrFloatBitWidth(); + bool isBool = srcBits == 1; + if (isBool) + srcBits = typeConverter.getOptions().boolNumBits; + Type pointeeType = typeConverter.convertType(memrefType) + .cast() + .getPointeeType(); + Type structElemType = pointeeType.cast().getElementType(0); + Type dstType; + if (auto arrayType = structElemType.dyn_cast()) + dstType = arrayType.getElementType(); + else + dstType = structElemType.cast().getElementType(); + + int dstBits = dstType.getIntOrFloatBitWidth(); + assert(dstBits % srcBits == 0); + + // If the rewrited load op has the same bit width, use the loading value + 
// directly. + if (srcBits == dstBits) { + rewriter.replaceOpWithNewOp(loadOp, + accessChainOp.getResult()); + return success(); + } + + // Assume that getElementPtr() works linearly. If it's a scalar, the method + // still returns a linearized accessing. If the accessing is not linearized, + // there will be offset issues. + assert(accessChainOp.indices().size() == 2); + Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp, + srcBits, dstBits, rewriter); + Value spvLoadOp = rewriter.create( + loc, dstType, adjustedPtr, + loadOp->getAttrOfType( + spirv::attributeName()), + loadOp->getAttrOfType("alignment")); + + // Shift the bits to the rightmost. + // ____XXXX________ -> ____________XXXX + Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1); + Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter); + Value result = rewriter.create( + loc, spvLoadOp.getType(), spvLoadOp, offset); + + // Apply the mask to extract corresponding bits. + Value mask = rewriter.create( + loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1)); + result = rewriter.create(loc, dstType, result, mask); + + // Apply sign extension on the loading value unconditionally. The signedness + // semantic is carried in the operator itself, we rely on other patterns to + // handle the casting. 
+ IntegerAttr shiftValueAttr = + rewriter.getIntegerAttr(dstType, dstBits - srcBits); + Value shiftValue = + rewriter.create(loc, dstType, shiftValueAttr); + result = rewriter.create(loc, dstType, result, + shiftValue); + result = rewriter.create(loc, dstType, result, + shiftValue); + + if (isBool) { + dstType = typeConverter.convertType(loadOp.getType()); + mask = spirv::ConstantOp::getOne(result.getType(), loc, rewriter); + Value isOne = rewriter.create(loc, result, mask); + Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter); + Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter); + result = rewriter.create(loc, dstType, isOne, one, zero); + } else if (result.getType().getIntOrFloatBitWidth() != + static_cast(dstBits)) { + result = rewriter.create(loc, dstType, result); + } + rewriter.replaceOp(loadOp, result); + + assert(accessChainOp.use_empty()); + rewriter.eraseOp(accessChainOp); + + return success(); +} + +LogicalResult +LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, ArrayRef operands, + ConversionPatternRewriter &rewriter) const { + memref::LoadOpAdaptor loadOperands(operands); + auto memrefType = loadOp.memref().getType().cast(); + if (memrefType.getElementType().isSignlessInteger()) + return failure(); + auto loadPtr = spirv::getElementPtr( + *getTypeConverter(), memrefType, + loadOperands.memref(), loadOperands.indices(), loadOp.getLoc(), rewriter); + rewriter.replaceOpWithNewOp(loadOp, loadPtr); + return success(); +} + +LogicalResult +IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, + ArrayRef operands, + ConversionPatternRewriter &rewriter) const { + memref::StoreOpAdaptor storeOperands(operands); + auto memrefType = storeOp.memref().getType().cast(); + if (!memrefType.getElementType().isSignlessInteger()) + return failure(); + + auto loc = storeOp.getLoc(); + auto &typeConverter = *getTypeConverter(); + spirv::AccessChainOp accessChainOp = + spirv::getElementPtr(typeConverter, memrefType, 
storeOperands.memref(), + storeOperands.indices(), loc, rewriter); + int srcBits = memrefType.getElementType().getIntOrFloatBitWidth(); + + bool isBool = srcBits == 1; + if (isBool) + srcBits = typeConverter.getOptions().boolNumBits; + Type pointeeType = typeConverter.convertType(memrefType) + .cast() + .getPointeeType(); + Type structElemType = pointeeType.cast().getElementType(0); + Type dstType; + if (auto arrayType = structElemType.dyn_cast()) + dstType = arrayType.getElementType(); + else + dstType = structElemType.cast().getElementType(); + + int dstBits = dstType.getIntOrFloatBitWidth(); + assert(dstBits % srcBits == 0); + + if (srcBits == dstBits) { + rewriter.replaceOpWithNewOp( + storeOp, accessChainOp.getResult(), storeOperands.value()); + return success(); + } + + // Since there are multi threads in the processing, the emulation will be done + // with atomic operations. E.g., if the storing value is i8, rewrite the + // StoreOp to + // 1) load a 32-bit integer + // 2) clear 8 bits in the loading value + // 3) store 32-bit value back + // 4) load a 32-bit integer + // 5) modify 8 bits in the loading value + // 6) store 32-bit value back + // The step 1 to step 3 are done by AtomicAnd as one atomic step, and the step + // 4 to step 6 are done by AtomicOr as another atomic step. + assert(accessChainOp.indices().size() == 2); + Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1); + Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter); + + // Create a mask to clear the destination. E.g., if it is the second i8 in + // i32, 0xFFFF00FF is created. 
+ Value mask = rewriter.create( + loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1)); + Value clearBitsMask = + rewriter.create(loc, dstType, mask, offset); + clearBitsMask = rewriter.create(loc, dstType, clearBitsMask); + + Value storeVal = storeOperands.value(); + if (isBool) { + Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter); + Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter); + storeVal = + rewriter.create(loc, dstType, storeVal, one, zero); + } + storeVal = shiftValue(loc, storeVal, offset, mask, dstBits, rewriter); + Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp, + srcBits, dstBits, rewriter); + Optional scope = getAtomicOpScope(memrefType); + if (!scope) + return failure(); + Value result = rewriter.create( + loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease, + clearBitsMask); + result = rewriter.create( + loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease, + storeVal); + + // The AtomicOrOp has no side effect. Since it is already inserted, we can + // just remove the original StoreOp. Note that rewriter.replaceOp() + // doesn't work because it only accepts that the numbers of result are the + // same. 
+ rewriter.eraseOp(storeOp); + + assert(accessChainOp.use_empty()); + rewriter.eraseOp(accessChainOp); + + return success(); +} + +LogicalResult +StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, + ArrayRef operands, + ConversionPatternRewriter &rewriter) const { + memref::StoreOpAdaptor storeOperands(operands); + auto memrefType = storeOp.memref().getType().cast(); + if (memrefType.getElementType().isSignlessInteger()) + return failure(); + auto storePtr = + spirv::getElementPtr(*getTypeConverter(), memrefType, + storeOperands.memref(), storeOperands.indices(), + storeOp.getLoc(), rewriter); + rewriter.replaceOpWithNewOp(storeOp, storePtr, + storeOperands.value()); + return success(); +} + +//===----------------------------------------------------------------------===// +// Pattern population +//===----------------------------------------------------------------------===// + +namespace mlir { +void populateMemRefToSPIRVPatterns(SPIRVTypeConverter &typeConverter, + RewritePatternSet &patterns) { + patterns.add( + typeConverter, patterns.getContext()); +} +} // namespace mlir diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.cpp @@ -0,0 +1,60 @@ +//===- MemRefToSPIRVPass.cpp - MemRef to SPIR-V Passes ----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a pass to convert MemRef dialect to SPIR-V dialect. 
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h" +#include "../PassDetail.h" +#include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h" +#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" +#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" + +using namespace mlir; + +namespace { +/// A pass converting MLIR MemRef operations into the SPIR-V dialect. +class ConvertMemRefToSPIRVPass + : public ConvertMemRefToSPIRVBase { + void runOnOperation() override; +}; +} // namespace + +void ConvertMemRefToSPIRVPass::runOnOperation() { + MLIRContext *context = &getContext(); + ModuleOp module = getOperation(); + + auto targetAttr = spirv::lookupTargetEnvOrDefault(module); + std::unique_ptr target = + SPIRVConversionTarget::get(targetAttr); + + SPIRVTypeConverter typeConverter(targetAttr); + + // Use UnrealizedConversionCast as the bridge so that we don't need to pull in + // patterns for other dialects. 
+ auto addUnrealizedCast = [](OpBuilder &builder, Type type, ValueRange inputs, + Location loc) { + auto cast = builder.create(loc, type, inputs); + return Optional(cast.getResult(0)); + }; + typeConverter.addSourceMaterialization(addUnrealizedCast); + typeConverter.addTargetMaterialization(addUnrealizedCast); + target->addLegalOp(); + + RewritePatternSet patterns(context); + populateMemRefToSPIRVPatterns(typeConverter, patterns); + + if (failed(applyPartialConversion(module, *target, std::move(patterns)))) + return signalPassFailure(); +} + +std::unique_ptr> +mlir::createConvertMemRefToSPIRVPass() { + return std::make_unique(); +} diff --git a/mlir/lib/Conversion/SCFToSPIRV/CMakeLists.txt b/mlir/lib/Conversion/SCFToSPIRV/CMakeLists.txt --- a/mlir/lib/Conversion/SCFToSPIRV/CMakeLists.txt +++ b/mlir/lib/Conversion/SCFToSPIRV/CMakeLists.txt @@ -9,6 +9,7 @@ MLIRConversionPassIncGen LINK_LIBS PUBLIC + MLIRMemRefToSPIRV MLIRSPIRV MLIRSPIRVConversion MLIRStandardToSPIRV diff --git a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRVPass.cpp b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRVPass.cpp --- a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRVPass.cpp +++ b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRVPass.cpp @@ -13,6 +13,7 @@ #include "mlir/Conversion/SCFToSPIRV/SCFToSPIRVPass.h" #include "../PassDetail.h" +#include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h" #include "mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h" #include "mlir/Conversion/StandardToSPIRV/StandardToSPIRV.h" #include "mlir/Dialect/SCF/SCF.h" @@ -39,7 +40,11 @@ ScfToSPIRVContext scfContext; RewritePatternSet patterns(context); populateSCFToSPIRVPatterns(typeConverter, scfContext, patterns); + + // TODO: Change SPIR-V conversion to be progressive and remove the following + // patterns. 
populateStandardToSPIRVPatterns(typeConverter, patterns); + populateMemRefToSPIRVPatterns(typeConverter, patterns); populateBuiltinFuncToSPIRVPatterns(typeConverter, patterns); if (failed(applyPartialConversion(module, *target, std::move(patterns)))) diff --git a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp --- a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp @@ -10,7 +10,6 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" @@ -128,101 +127,6 @@ return builder.create(loc, type, isPositive, abs, absNegate); } -/// Returns the offset of the value in `targetBits` representation. -/// -/// `srcIdx` is an index into a 1-D array with each element having `sourceBits`. -/// It's assumed to be non-negative. -/// -/// When accessing an element in the array treating as having elements of -/// `targetBits`, multiple values are loaded in the same time. The method -/// returns the offset where the `srcIdx` locates in the value. For example, if -/// `sourceBits` equals to 8 and `targetBits` equals to 32, the x-th element is -/// located at (x % 4) * 8. Because there are four elements in one i32, and one -/// element has 8 bits. 
-static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits, - int targetBits, OpBuilder &builder) { - assert(targetBits % sourceBits == 0); - IntegerType targetType = builder.getIntegerType(targetBits); - IntegerAttr idxAttr = - builder.getIntegerAttr(targetType, targetBits / sourceBits); - auto idx = builder.create(loc, targetType, idxAttr); - IntegerAttr srcBitsAttr = builder.getIntegerAttr(targetType, sourceBits); - auto srcBitsValue = - builder.create(loc, targetType, srcBitsAttr); - auto m = builder.create(loc, srcIdx, idx); - return builder.create(loc, targetType, m, srcBitsValue); -} - -/// Returns an adjusted spirv::AccessChainOp. Based on the -/// extension/capabilities, certain integer bitwidths `sourceBits` might not be -/// supported. During conversion if a memref of an unsupported type is used, -/// load/stores to this memref need to be modified to use a supported higher -/// bitwidth `targetBits` and extracting the required bits. For an accessing a -/// 1D array (spv.array or spv.rt_array), the last index is modified to load the -/// bits needed. The extraction of the actual bits needed are handled -/// separately. Note that this only works for a 1-D tensor. -static Value adjustAccessChainForBitwidth(SPIRVTypeConverter &typeConverter, - spirv::AccessChainOp op, - int sourceBits, int targetBits, - OpBuilder &builder) { - assert(targetBits % sourceBits == 0); - const auto loc = op.getLoc(); - IntegerType targetType = builder.getIntegerType(targetBits); - IntegerAttr attr = - builder.getIntegerAttr(targetType, targetBits / sourceBits); - auto idx = builder.create(loc, targetType, attr); - auto lastDim = op->getOperand(op.getNumOperands() - 1); - auto indices = llvm::to_vector<4>(op.indices()); - // There are two elements if this is a 1-D tensor. 
- assert(indices.size() == 2); - indices.back() = builder.create(loc, lastDim, idx); - Type t = typeConverter.convertType(op.component_ptr().getType()); - return builder.create(loc, t, op.base_ptr(), indices); -} - -/// Returns the shifted `targetBits`-bit value with the given offset. -static Value shiftValue(Location loc, Value value, Value offset, Value mask, - int targetBits, OpBuilder &builder) { - Type targetType = builder.getIntegerType(targetBits); - Value result = builder.create(loc, value, mask); - return builder.create(loc, targetType, result, - offset); -} - -/// Returns true if the allocations of type `t` can be lowered to SPIR-V. -static bool isAllocationSupported(MemRefType t) { - // Currently only support workgroup local memory allocations with static - // shape and int or float or vector of int or float element type. - if (!(t.hasStaticShape() && - SPIRVTypeConverter::getMemorySpaceForStorageClass( - spirv::StorageClass::Workgroup) == t.getMemorySpaceAsInt())) - return false; - Type elementType = t.getElementType(); - if (auto vecType = elementType.dyn_cast()) - elementType = vecType.getElementType(); - return elementType.isIntOrFloat(); -} - -/// Returns the scope to use for atomic operations use for emulating store -/// operations of unsupported integer bitwidths, based on the memref -/// type. Returns None on failure. 
-static Optional getAtomicOpScope(MemRefType t) { - Optional storageClass = - SPIRVTypeConverter::getStorageClassForMemorySpace( - t.getMemorySpaceAsInt()); - if (!storageClass) - return {}; - switch (*storageClass) { - case spirv::StorageClass::StorageBuffer: - return spirv::Scope::Device; - case spirv::StorageClass::Workgroup: - return spirv::Scope::Workgroup; - default: { - } - } - return {}; -} - //===----------------------------------------------------------------------===// // Operation conversion //===----------------------------------------------------------------------===// @@ -233,66 +137,6 @@ namespace { -/// Converts an allocation operation to SPIR-V. Currently only supports lowering -/// to Workgroup memory when the size is constant. Note that this pattern needs -/// to be applied in a pass that runs at least at spv.module scope since it wil -/// ladd global variables into the spv.module. -class AllocOpPattern final : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(memref::AllocOp operation, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { - MemRefType allocType = operation.getType(); - if (!isAllocationSupported(allocType)) - return operation.emitError("unhandled allocation type"); - - // Get the SPIR-V type for the allocation. - Type spirvType = getTypeConverter()->convertType(allocType); - - // Insert spv.GlobalVariable for this allocation. 
- Operation *parent = - SymbolTable::getNearestSymbolTable(operation->getParentOp()); - if (!parent) - return failure(); - Location loc = operation.getLoc(); - spirv::GlobalVariableOp varOp; - { - OpBuilder::InsertionGuard guard(rewriter); - Block &entryBlock = *parent->getRegion(0).begin(); - rewriter.setInsertionPointToStart(&entryBlock); - auto varOps = entryBlock.getOps(); - std::string varName = - std::string("__workgroup_mem__") + - std::to_string(std::distance(varOps.begin(), varOps.end())); - varOp = rewriter.create(loc, spirvType, varName, - /*initializer=*/nullptr); - } - - // Get pointer to global variable at the current scope. - rewriter.replaceOpWithNewOp(operation, varOp); - return success(); - } -}; - -/// Removed a deallocation if it is a supported allocation. Currently only -/// removes deallocation if the memory space is workgroup memory. -class DeallocOpPattern final : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(memref::DeallocOp operation, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { - MemRefType deallocType = operation.memref().getType().cast(); - if (!isAllocationSupported(deallocType)) - return operation.emitError("unhandled deallocation type"); - rewriter.eraseOp(operation); - return success(); - } -}; - /// Converts unary and binary standard operations to SPIR-V operations. template class UnaryAndBinaryOpPattern final : public OpConversionPattern { @@ -430,26 +274,6 @@ ConversionPatternRewriter &rewriter) const override; }; -/// Converts memref.load to spv.Load. -class IntLoadOpPattern final : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(memref::LoadOp loadOp, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override; -}; - -/// Converts memref.load to spv.Load. 
-class LoadOpPattern final : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(memref::LoadOp loadOp, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override; -}; - /// Converts std.return to spv.Return. class ReturnOpPattern final : public OpConversionPattern { public: @@ -479,26 +303,6 @@ ConversionPatternRewriter &rewriter) const override; }; -/// Converts memref.store to spv.Store on integers. -class IntStoreOpPattern final : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(memref::StoreOp storeOp, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override; -}; - -/// Converts memref.store to spv.Store. -class StoreOpPattern final : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(memref::StoreOp storeOp, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override; -}; - /// Converts std.zexti to spv.Select if the type of source is i1 or vector of /// i1. 
class ZeroExtendI1Pattern final : public OpConversionPattern { @@ -991,119 +795,6 @@ return failure(); } -//===----------------------------------------------------------------------===// -// LoadOp -//===----------------------------------------------------------------------===// - -LogicalResult -IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, - ArrayRef operands, - ConversionPatternRewriter &rewriter) const { - memref::LoadOpAdaptor loadOperands(operands); - auto loc = loadOp.getLoc(); - auto memrefType = loadOp.memref().getType().cast(); - if (!memrefType.getElementType().isSignlessInteger()) - return failure(); - - auto &typeConverter = *getTypeConverter(); - spirv::AccessChainOp accessChainOp = - spirv::getElementPtr(typeConverter, memrefType, loadOperands.memref(), - loadOperands.indices(), loc, rewriter); - - int srcBits = memrefType.getElementType().getIntOrFloatBitWidth(); - bool isBool = srcBits == 1; - if (isBool) - srcBits = typeConverter.getOptions().boolNumBits; - Type pointeeType = typeConverter.convertType(memrefType) - .cast() - .getPointeeType(); - Type structElemType = pointeeType.cast().getElementType(0); - Type dstType; - if (auto arrayType = structElemType.dyn_cast()) - dstType = arrayType.getElementType(); - else - dstType = structElemType.cast().getElementType(); - - int dstBits = dstType.getIntOrFloatBitWidth(); - assert(dstBits % srcBits == 0); - - // If the rewrited load op has the same bit width, use the loading value - // directly. - if (srcBits == dstBits) { - rewriter.replaceOpWithNewOp(loadOp, - accessChainOp.getResult()); - return success(); - } - - // Assume that getElementPtr() works linearizely. If it's a scalar, the method - // still returns a linearized accessing. If the accessing is not linearized, - // there will be offset issues. 
- assert(accessChainOp.indices().size() == 2); - Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp, - srcBits, dstBits, rewriter); - Value spvLoadOp = rewriter.create( - loc, dstType, adjustedPtr, - loadOp->getAttrOfType( - spirv::attributeName()), - loadOp->getAttrOfType("alignment")); - - // Shift the bits to the rightmost. - // ____XXXX________ -> ____________XXXX - Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1); - Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter); - Value result = rewriter.create( - loc, spvLoadOp.getType(), spvLoadOp, offset); - - // Apply the mask to extract corresponding bits. - Value mask = rewriter.create( - loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1)); - result = rewriter.create(loc, dstType, result, mask); - - // Apply sign extension on the loading value unconditionally. The signedness - // semantic is carried in the operator itself, we relies other pattern to - // handle the casting. 
- IntegerAttr shiftValueAttr = - rewriter.getIntegerAttr(dstType, dstBits - srcBits); - Value shiftValue = - rewriter.create(loc, dstType, shiftValueAttr); - result = rewriter.create(loc, dstType, result, - shiftValue); - result = rewriter.create(loc, dstType, result, - shiftValue); - - if (isBool) { - dstType = typeConverter.convertType(loadOp.getType()); - mask = spirv::ConstantOp::getOne(result.getType(), loc, rewriter); - Value isOne = rewriter.create(loc, result, mask); - Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter); - Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter); - result = rewriter.create(loc, dstType, isOne, one, zero); - } else if (result.getType().getIntOrFloatBitWidth() != - static_cast(dstBits)) { - result = rewriter.create(loc, dstType, result); - } - rewriter.replaceOp(loadOp, result); - - assert(accessChainOp.use_empty()); - rewriter.eraseOp(accessChainOp); - - return success(); -} - -LogicalResult -LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, ArrayRef operands, - ConversionPatternRewriter &rewriter) const { - memref::LoadOpAdaptor loadOperands(operands); - auto memrefType = loadOp.memref().getType().cast(); - if (memrefType.getElementType().isSignlessInteger()) - return failure(); - auto loadPtr = spirv::getElementPtr( - *getTypeConverter(), memrefType, - loadOperands.memref(), loadOperands.indices(), loadOp.getLoc(), rewriter); - rewriter.replaceOpWithNewOp(loadOp, loadPtr); - return success(); -} - //===----------------------------------------------------------------------===// // ReturnOp //===----------------------------------------------------------------------===// @@ -1153,120 +844,6 @@ return success(); } -//===----------------------------------------------------------------------===// -// StoreOp -//===----------------------------------------------------------------------===// - -LogicalResult -IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, - ArrayRef operands, - 
ConversionPatternRewriter &rewriter) const { - memref::StoreOpAdaptor storeOperands(operands); - auto memrefType = storeOp.memref().getType().cast(); - if (!memrefType.getElementType().isSignlessInteger()) - return failure(); - - auto loc = storeOp.getLoc(); - auto &typeConverter = *getTypeConverter(); - spirv::AccessChainOp accessChainOp = - spirv::getElementPtr(typeConverter, memrefType, storeOperands.memref(), - storeOperands.indices(), loc, rewriter); - int srcBits = memrefType.getElementType().getIntOrFloatBitWidth(); - - bool isBool = srcBits == 1; - if (isBool) - srcBits = typeConverter.getOptions().boolNumBits; - Type pointeeType = typeConverter.convertType(memrefType) - .cast() - .getPointeeType(); - Type structElemType = pointeeType.cast().getElementType(0); - Type dstType; - if (auto arrayType = structElemType.dyn_cast()) - dstType = arrayType.getElementType(); - else - dstType = structElemType.cast().getElementType(); - - int dstBits = dstType.getIntOrFloatBitWidth(); - assert(dstBits % srcBits == 0); - - if (srcBits == dstBits) { - rewriter.replaceOpWithNewOp( - storeOp, accessChainOp.getResult(), storeOperands.value()); - return success(); - } - - // Since there are multi threads in the processing, the emulation will be done - // with atomic operations. E.g., if the storing value is i8, rewrite the - // StoreOp to - // 1) load a 32-bit integer - // 2) clear 8 bits in the loading value - // 3) store 32-bit value back - // 4) load a 32-bit integer - // 5) modify 8 bits in the loading value - // 6) store 32-bit value back - // The step 1 to step 3 are done by AtomicAnd as one atomic step, and the step - // 4 to step 6 are done by AtomicOr as another atomic step. - assert(accessChainOp.indices().size() == 2); - Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1); - Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter); - - // Create a mask to clear the destination. 
E.g., if it is the second i8 in - // i32, 0xFFFF00FF is created. - Value mask = rewriter.create( - loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1)); - Value clearBitsMask = - rewriter.create(loc, dstType, mask, offset); - clearBitsMask = rewriter.create(loc, dstType, clearBitsMask); - - Value storeVal = storeOperands.value(); - if (isBool) { - Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter); - Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter); - storeVal = - rewriter.create(loc, dstType, storeVal, one, zero); - } - storeVal = shiftValue(loc, storeVal, offset, mask, dstBits, rewriter); - Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp, - srcBits, dstBits, rewriter); - Optional scope = getAtomicOpScope(memrefType); - if (!scope) - return failure(); - Value result = rewriter.create( - loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease, - clearBitsMask); - result = rewriter.create( - loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease, - storeVal); - - // The AtomicOrOp has no side effect. Since it is already inserted, we can - // just remove the original StoreOp. Note that rewriter.replaceOp() - // doesn't work because it only accepts that the numbers of result are the - // same. 
- rewriter.eraseOp(storeOp); - - assert(accessChainOp.use_empty()); - rewriter.eraseOp(accessChainOp); - - return success(); -} - -LogicalResult -StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, - ArrayRef operands, - ConversionPatternRewriter &rewriter) const { - memref::StoreOpAdaptor storeOperands(operands); - auto memrefType = storeOp.memref().getType().cast(); - if (memrefType.getElementType().isSignlessInteger()) - return failure(); - auto storePtr = - spirv::getElementPtr(*getTypeConverter(), memrefType, - storeOperands.memref(), storeOperands.indices(), - storeOp.getLoc(), rewriter); - rewriter.replaceOpWithNewOp(storeOp, storePtr, - storeOperands.value()); - return success(); -} - //===----------------------------------------------------------------------===// // XorOp //===----------------------------------------------------------------------===// @@ -1343,10 +920,6 @@ // Constant patterns ConstantCompositeOpPattern, ConstantScalarOpPattern, - // Memory patterns - AllocOpPattern, DeallocOpPattern, IntLoadOpPattern, IntStoreOpPattern, - LoadOpPattern, StoreOpPattern, - ReturnOpPattern, SelectOpPattern, SplatPattern, // Type cast patterns diff --git a/mlir/test/Conversion/StandardToSPIRV/alloc.mlir b/mlir/test/Conversion/MemRefToSPIRV/alloc.mlir rename from mlir/test/Conversion/StandardToSPIRV/alloc.mlir rename to mlir/test/Conversion/MemRefToSPIRV/alloc.mlir --- a/mlir/test/Conversion/StandardToSPIRV/alloc.mlir +++ b/mlir/test/Conversion/MemRefToSPIRV/alloc.mlir @@ -1,8 +1,4 @@ -// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -convert-std-to-spirv -canonicalize -verify-diagnostics %s -o - | FileCheck %s - -//===----------------------------------------------------------------------===// -// memref allocation/deallocation ops -//===----------------------------------------------------------------------===// +// RUN: mlir-opt -split-input-file -convert-memref-to-spirv -canonicalize -verify-diagnostics %s -o - | FileCheck %s module 
attributes { spv.target_env = #spv.target_env< @@ -26,7 +22,6 @@ // CHECK: %[[STOREPTR:.+]] = spv.AccessChain %[[PTR]] // CHECK: spv.Store "Workgroup" %[[STOREPTR]], %[[VAL]] : f32 // CHECK-NOT: memref.dealloc -// CHECK: spv.Return // ----- @@ -75,8 +70,7 @@ // CHECK-SAME: !spv.ptr)>, Workgroup> // CHECK-DAG: spv.GlobalVariable @__workgroup_mem__{{[0-9]+}} // CHECK-SAME: !spv.ptr)>, Workgroup> -// CHECK: spv.func @two_allocs() -// CHECK: spv.Return +// CHECK: func @two_allocs() // ----- @@ -96,8 +90,7 @@ // CHECK-SAME: !spv.ptr, stride=8>)>, Workgroup> // CHECK-DAG: spv.GlobalVariable @__workgroup_mem__{{[0-9]+}} // CHECK-SAME: !spv.ptr, stride=16>)>, Workgroup> -// CHECK: spv.func @two_allocs_vector() -// CHECK: spv.Return +// CHECK: func @two_allocs_vector() // ----- @@ -108,8 +101,7 @@ } { func @alloc_dealloc_dynamic_workgroup_mem(%arg0 : index) { - // expected-error @+2 {{unhandled allocation type}} - // expected-error @+1 {{'memref.alloc' op operand #0 must be index}} + // expected-error @+1 {{unhandled allocation type}} %0 = memref.alloc(%arg0) : memref<4x?xf32, 3> return } @@ -138,8 +130,7 @@ } { func @alloc_dealloc_dynamic_workgroup_mem(%arg0 : memref<4x?xf32, 3>) { - // expected-error @+2 {{unhandled deallocation type}} - // expected-error @+1 {{'memref.dealloc' op operand #0 must be unranked.memref of any type values or memref of any type values}} + // expected-error @+1 {{unhandled deallocation type}} memref.dealloc %arg0 : memref<4x?xf32, 3> return } @@ -153,8 +144,7 @@ } { func @alloc_dealloc_mem(%arg0 : memref<4x5xf32>) { - // expected-error @+2 {{unhandled deallocation type}} - // expected-error @+1 {{op operand #0 must be unranked.memref of any type values or memref of any type values}} + // expected-error @+1 {{unhandled deallocation type}} memref.dealloc %arg0 : memref<4x5xf32> return } diff --git a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir new file mode 100644 --- /dev/null 
+++ b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir @@ -0,0 +1,338 @@ +// RUN: mlir-opt -split-input-file -convert-memref-to-spirv -verify-diagnostics %s -o - | FileCheck %s + +module attributes { + spv.target_env = #spv.target_env< + #spv.vce, {}> +} { + +// CHECK-LABEL: @load_store_zero_rank_float +func @load_store_zero_rank_float(%arg0: memref, %arg1: memref) { + // CHECK: [[ARG0:%.*]] = unrealized_conversion_cast {{.+}} : memref to !spv.ptr [0])>, StorageBuffer> + // CHECK: [[ZERO1:%.*]] = spv.Constant 0 : i32 + // CHECK: spv.AccessChain [[ARG0]][ + // CHECK-SAME: [[ZERO1]], [[ZERO1]] + // CHECK-SAME: ] : + // CHECK: spv.Load "StorageBuffer" %{{.*}} : f32 + %0 = memref.load %arg0[] : memref + // CHECK: [[ARG1:%.*]] = unrealized_conversion_cast {{.+}} : memref to !spv.ptr [0])>, StorageBuffer> + // CHECK: [[ZERO2:%.*]] = spv.Constant 0 : i32 + // CHECK: spv.AccessChain [[ARG1]][ + // CHECK-SAME: [[ZERO2]], [[ZERO2]] + // CHECK-SAME: ] : + // CHECK: spv.Store "StorageBuffer" %{{.*}} : f32 + memref.store %0, %arg1[] : memref + return +} + +// CHECK-LABEL: @load_store_zero_rank_int +func @load_store_zero_rank_int(%arg0: memref, %arg1: memref) { + // CHECK: [[ARG0:%.*]] = unrealized_conversion_cast {{.+}} : memref to !spv.ptr [0])>, StorageBuffer> + // CHECK: [[ZERO1:%.*]] = spv.Constant 0 : i32 + // CHECK: spv.AccessChain [[ARG0]][ + // CHECK-SAME: [[ZERO1]], [[ZERO1]] + // CHECK-SAME: ] : + // CHECK: spv.Load "StorageBuffer" %{{.*}} : i32 + %0 = memref.load %arg0[] : memref + // CHECK: [[ARG1:%.*]] = unrealized_conversion_cast {{.+}} : memref to !spv.ptr [0])>, StorageBuffer> + // CHECK: [[ZERO2:%.*]] = spv.Constant 0 : i32 + // CHECK: spv.AccessChain [[ARG1]][ + // CHECK-SAME: [[ZERO2]], [[ZERO2]] + // CHECK-SAME: ] : + // CHECK: spv.Store "StorageBuffer" %{{.*}} : i32 + memref.store %0, %arg1[] : memref + return +} + +// CHECK-LABEL: func @load_store_unknown_dim +func @load_store_unknown_dim(%i: index, %source: memref, %dest: memref) { + // CHECK: 
%[[SRC:.+]] = unrealized_conversion_cast {{.+}} : memref to !spv.ptr [0])>, StorageBuffer> + // CHECK: %[[AC0:.+]] = spv.AccessChain %[[SRC]] + // CHECK: spv.Load "StorageBuffer" %[[AC0]] + %0 = memref.load %source[%i] : memref + // CHECK: %[[DST:.+]] = unrealized_conversion_cast {{.+}} : memref to !spv.ptr [0])>, StorageBuffer> + // CHECK: %[[AC1:.+]] = spv.AccessChain %[[DST]] + // CHECK: spv.Store "StorageBuffer" %[[AC1]] + memref.store %0, %dest[%i]: memref + return +} + +} // end module + +// ----- + +// Check that access chain indices are properly adjusted if non-32-bit types are +// emulated via 32-bit types. +// TODO: Test i64 types. +module attributes { + spv.target_env = #spv.target_env< + #spv.vce, {}> +} { + +// CHECK-LABEL: @load_i1 +func @load_i1(%arg0: memref) -> i1 { + // CHECK: %[[ZERO:.+]] = spv.Constant 0 : i32 + // CHECK: %[[FOUR1:.+]] = spv.Constant 4 : i32 + // CHECK: %[[QUOTIENT:.+]] = spv.SDiv %[[ZERO]], %[[FOUR1]] : i32 + // CHECK: %[[PTR:.+]] = spv.AccessChain %{{.+}}[%[[ZERO]], %[[QUOTIENT]]] + // CHECK: %[[LOAD:.+]] = spv.Load "StorageBuffer" %[[PTR]] + // CHECK: %[[FOUR2:.+]] = spv.Constant 4 : i32 + // CHECK: %[[EIGHT:.+]] = spv.Constant 8 : i32 + // CHECK: %[[IDX:.+]] = spv.UMod %[[ZERO]], %[[FOUR2]] : i32 + // CHECK: %[[BITS:.+]] = spv.IMul %[[IDX]], %[[EIGHT]] : i32 + // CHECK: %[[VALUE:.+]] = spv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i32 + // CHECK: %[[MASK:.+]] = spv.Constant 255 : i32 + // CHECK: %[[T1:.+]] = spv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32 + // CHECK: %[[T2:.+]] = spv.Constant 24 : i32 + // CHECK: %[[T3:.+]] = spv.ShiftLeftLogical %[[T1]], %[[T2]] : i32, i32 + // CHECK: %[[T4:.+]] = spv.ShiftRightArithmetic %[[T3]], %[[T2]] : i32, i32 + // Convert to i1 type. 
+ // CHECK: %[[ONE:.+]] = spv.Constant 1 : i32 + // CHECK: %[[ISONE:.+]] = spv.IEqual %[[T4]], %[[ONE]] : i32 + // CHECK: %[[FALSE:.+]] = spv.Constant false + // CHECK: %[[TRUE:.+]] = spv.Constant true + // CHECK: %[[RES:.+]] = spv.Select %[[ISONE]], %[[TRUE]], %[[FALSE]] : i1, i1 + // CHECK: return %[[RES]] + %0 = memref.load %arg0[] : memref + return %0 : i1 +} + +// CHECK-LABEL: @load_i8 +func @load_i8(%arg0: memref) { + // CHECK: %[[ZERO:.+]] = spv.Constant 0 : i32 + // CHECK: %[[FOUR1:.+]] = spv.Constant 4 : i32 + // CHECK: %[[QUOTIENT:.+]] = spv.SDiv %[[ZERO]], %[[FOUR1]] : i32 + // CHECK: %[[PTR:.+]] = spv.AccessChain %{{.+}}[%[[ZERO]], %[[QUOTIENT]]] + // CHECK: %[[LOAD:.+]] = spv.Load "StorageBuffer" %[[PTR]] + // CHECK: %[[FOUR2:.+]] = spv.Constant 4 : i32 + // CHECK: %[[EIGHT:.+]] = spv.Constant 8 : i32 + // CHECK: %[[IDX:.+]] = spv.UMod %[[ZERO]], %[[FOUR2]] : i32 + // CHECK: %[[BITS:.+]] = spv.IMul %[[IDX]], %[[EIGHT]] : i32 + // CHECK: %[[VALUE:.+]] = spv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i32 + // CHECK: %[[MASK:.+]] = spv.Constant 255 : i32 + // CHECK: %[[T1:.+]] = spv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32 + // CHECK: %[[T2:.+]] = spv.Constant 24 : i32 + // CHECK: %[[T3:.+]] = spv.ShiftLeftLogical %[[T1]], %[[T2]] : i32, i32 + // CHECK: spv.ShiftRightArithmetic %[[T3]], %[[T2]] : i32, i32 + %0 = memref.load %arg0[] : memref + return +} + +// CHECK-LABEL: @load_i16 +// CHECK: (%[[ARG0:.+]]: {{.*}}, %[[ARG1:.+]]: index) +func @load_i16(%arg0: memref<10xi16>, %index : index) { + // CHECK: %[[ARG1_CAST:.+]] = unrealized_conversion_cast %[[ARG1]] : index to i32 + // CHECK: %[[ZERO:.+]] = spv.Constant 0 : i32 + // CHECK: %[[OFFSET:.+]] = spv.Constant 0 : i32 + // CHECK: %[[ONE:.+]] = spv.Constant 1 : i32 + // CHECK: %[[UPDATE:.+]] = spv.IMul %[[ONE]], %[[ARG1_CAST]] : i32 + // CHECK: %[[FLAT_IDX:.+]] = spv.IAdd %[[OFFSET]], %[[UPDATE]] : i32 + // CHECK: %[[TWO1:.+]] = spv.Constant 2 : i32 + // CHECK: %[[QUOTIENT:.+]] = spv.SDiv 
%[[FLAT_IDX]], %[[TWO1]] : i32 + // CHECK: %[[PTR:.+]] = spv.AccessChain %{{.+}}[%[[ZERO]], %[[QUOTIENT]]] + // CHECK: %[[LOAD:.+]] = spv.Load "StorageBuffer" %[[PTR]] + // CHECK: %[[TWO2:.+]] = spv.Constant 2 : i32 + // CHECK: %[[SIXTEEN:.+]] = spv.Constant 16 : i32 + // CHECK: %[[IDX:.+]] = spv.UMod %[[FLAT_IDX]], %[[TWO2]] : i32 + // CHECK: %[[BITS:.+]] = spv.IMul %[[IDX]], %[[SIXTEEN]] : i32 + // CHECK: %[[VALUE:.+]] = spv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i32 + // CHECK: %[[MASK:.+]] = spv.Constant 65535 : i32 + // CHECK: %[[T1:.+]] = spv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32 + // CHECK: %[[T2:.+]] = spv.Constant 16 : i32 + // CHECK: %[[T3:.+]] = spv.ShiftLeftLogical %[[T1]], %[[T2]] : i32, i32 + // CHECK: spv.ShiftRightArithmetic %[[T3]], %[[T2]] : i32, i32 + %0 = memref.load %arg0[%index] : memref<10xi16> + return +} + +// CHECK-LABEL: @load_i32 +func @load_i32(%arg0: memref) { + // CHECK-NOT: spv.SDiv + // CHECK: spv.Load + // CHECK-NOT: spv.ShiftRightArithmetic + %0 = memref.load %arg0[] : memref + return +} + +// CHECK-LABEL: @load_f32 +func @load_f32(%arg0: memref) { + // CHECK-NOT: spv.SDiv + // CHECK: spv.Load + // CHECK-NOT: spv.ShiftRightArithmetic + %0 = memref.load %arg0[] : memref + return +} + +// CHECK-LABEL: @store_i1 +// CHECK: (%[[ARG0:.+]]: {{.*}}, %[[ARG1:.+]]: i1) +func @store_i1(%arg0: memref, %value: i1) { + // CHECK: %[[ARG0_CAST:.+]] = unrealized_conversion_cast %[[ARG0]] + // CHECK: %[[ZERO:.+]] = spv.Constant 0 : i32 + // CHECK: %[[FOUR:.+]] = spv.Constant 4 : i32 + // CHECK: %[[EIGHT:.+]] = spv.Constant 8 : i32 + // CHECK: %[[IDX:.+]] = spv.UMod %[[ZERO]], %[[FOUR]] : i32 + // CHECK: %[[OFFSET:.+]] = spv.IMul %[[IDX]], %[[EIGHT]] : i32 + // CHECK: %[[MASK1:.+]] = spv.Constant 255 : i32 + // CHECK: %[[TMP1:.+]] = spv.ShiftLeftLogical %[[MASK1]], %[[OFFSET]] : i32, i32 + // CHECK: %[[MASK:.+]] = spv.Not %[[TMP1]] : i32 + // CHECK: %[[ZERO1:.+]] = spv.Constant 0 : i32 + // CHECK: %[[ONE1:.+]] = spv.Constant 1 : i32 + // 
CHECK: %[[CASTED_ARG1:.+]] = spv.Select %[[ARG1]], %[[ONE1]], %[[ZERO1]] : i1, i32 + // CHECK: %[[CLAMPED_VAL:.+]] = spv.BitwiseAnd %[[CASTED_ARG1]], %[[MASK1]] : i32 + // CHECK: %[[STORE_VAL:.+]] = spv.ShiftLeftLogical %[[CLAMPED_VAL]], %[[OFFSET]] : i32, i32 + // CHECK: %[[FOUR2:.+]] = spv.Constant 4 : i32 + // CHECK: %[[ACCESS_IDX:.+]] = spv.SDiv %[[ZERO]], %[[FOUR2]] : i32 + // CHECK: %[[PTR:.+]] = spv.AccessChain %[[ARG0_CAST]][%[[ZERO]], %[[ACCESS_IDX]]] + // CHECK: spv.AtomicAnd "Device" "AcquireRelease" %[[PTR]], %[[MASK]] + // CHECK: spv.AtomicOr "Device" "AcquireRelease" %[[PTR]], %[[STORE_VAL]] + memref.store %value, %arg0[] : memref + return +} + +// CHECK-LABEL: @store_i8 +// CHECK: (%[[ARG0:.+]]: {{.*}}, %[[ARG1:.+]]: i8) +func @store_i8(%arg0: memref, %value: i8) { + // CHECK: %[[ARG1_CAST:.+]] = unrealized_conversion_cast %[[ARG1]] : i8 to i32 + // CHECK: %[[ARG0_CAST:.+]] = unrealized_conversion_cast %[[ARG0]] + // CHECK: %[[ZERO:.+]] = spv.Constant 0 : i32 + // CHECK: %[[FOUR:.+]] = spv.Constant 4 : i32 + // CHECK: %[[EIGHT:.+]] = spv.Constant 8 : i32 + // CHECK: %[[IDX:.+]] = spv.UMod %[[ZERO]], %[[FOUR]] : i32 + // CHECK: %[[OFFSET:.+]] = spv.IMul %[[IDX]], %[[EIGHT]] : i32 + // CHECK: %[[MASK1:.+]] = spv.Constant 255 : i32 + // CHECK: %[[TMP1:.+]] = spv.ShiftLeftLogical %[[MASK1]], %[[OFFSET]] : i32, i32 + // CHECK: %[[MASK:.+]] = spv.Not %[[TMP1]] : i32 + // CHECK: %[[CLAMPED_VAL:.+]] = spv.BitwiseAnd %[[ARG1_CAST]], %[[MASK1]] : i32 + // CHECK: %[[STORE_VAL:.+]] = spv.ShiftLeftLogical %[[CLAMPED_VAL]], %[[OFFSET]] : i32, i32 + // CHECK: %[[FOUR2:.+]] = spv.Constant 4 : i32 + // CHECK: %[[ACCESS_IDX:.+]] = spv.SDiv %[[ZERO]], %[[FOUR2]] : i32 + // CHECK: %[[PTR:.+]] = spv.AccessChain %[[ARG0_CAST]][%[[ZERO]], %[[ACCESS_IDX]]] + // CHECK: spv.AtomicAnd "Device" "AcquireRelease" %[[PTR]], %[[MASK]] + // CHECK: spv.AtomicOr "Device" "AcquireRelease" %[[PTR]], %[[STORE_VAL]] + memref.store %value, %arg0[] : memref + return +} + +// CHECK-LABEL: 
@store_i16 +// CHECK: (%[[ARG0:.+]]: memref<10xi16>, %[[ARG1:.+]]: index, %[[ARG2:.+]]: i16) +func @store_i16(%arg0: memref<10xi16>, %index: index, %value: i16) { + // CHECK: %[[ARG2_CAST:.+]] = unrealized_conversion_cast %[[ARG2]] : i16 to i32 + // CHECK: %[[ARG0_CAST:.+]] = unrealized_conversion_cast %[[ARG0]] + // CHECK: %[[ARG1_CAST:.+]] = unrealized_conversion_cast %[[ARG1]] : index to i32 + // CHECK: %[[ZERO:.+]] = spv.Constant 0 : i32 + // CHECK: %[[OFFSET:.+]] = spv.Constant 0 : i32 + // CHECK: %[[ONE:.+]] = spv.Constant 1 : i32 + // CHECK: %[[UPDATE:.+]] = spv.IMul %[[ONE]], %[[ARG1_CAST]] : i32 + // CHECK: %[[FLAT_IDX:.+]] = spv.IAdd %[[OFFSET]], %[[UPDATE]] : i32 + // CHECK: %[[TWO:.+]] = spv.Constant 2 : i32 + // CHECK: %[[SIXTEEN:.+]] = spv.Constant 16 : i32 + // CHECK: %[[IDX:.+]] = spv.UMod %[[FLAT_IDX]], %[[TWO]] : i32 + // CHECK: %[[OFFSET:.+]] = spv.IMul %[[IDX]], %[[SIXTEEN]] : i32 + // CHECK: %[[MASK1:.+]] = spv.Constant 65535 : i32 + // CHECK: %[[TMP1:.+]] = spv.ShiftLeftLogical %[[MASK1]], %[[OFFSET]] : i32, i32 + // CHECK: %[[MASK:.+]] = spv.Not %[[TMP1]] : i32 + // CHECK: %[[CLAMPED_VAL:.+]] = spv.BitwiseAnd %[[ARG2_CAST]], %[[MASK1]] : i32 + // CHECK: %[[STORE_VAL:.+]] = spv.ShiftLeftLogical %[[CLAMPED_VAL]], %[[OFFSET]] : i32, i32 + // CHECK: %[[TWO2:.+]] = spv.Constant 2 : i32 + // CHECK: %[[ACCESS_IDX:.+]] = spv.SDiv %[[FLAT_IDX]], %[[TWO2]] : i32 + // CHECK: %[[PTR:.+]] = spv.AccessChain %[[ARG0_CAST]][%[[ZERO]], %[[ACCESS_IDX]]] + // CHECK: spv.AtomicAnd "Device" "AcquireRelease" %[[PTR]], %[[MASK]] + // CHECK: spv.AtomicOr "Device" "AcquireRelease" %[[PTR]], %[[STORE_VAL]] + memref.store %value, %arg0[%index] : memref<10xi16> + return +} + +// CHECK-LABEL: @store_i32 +func @store_i32(%arg0: memref, %value: i32) { + // CHECK: spv.Store + // CHECK-NOT: spv.AtomicAnd + // CHECK-NOT: spv.AtomicOr + memref.store %value, %arg0[] : memref + return +} + +// CHECK-LABEL: @store_f32 +func @store_f32(%arg0: memref, %value: f32) { + // CHECK: 
spv.Store + // CHECK-NOT: spv.AtomicAnd + // CHECK-NOT: spv.AtomicOr + memref.store %value, %arg0[] : memref + return +} + +} // end module + +// ----- + +// Check that access chain indices are properly adjusted if non-16/32-bit types +// are emulated via 32-bit types. +module attributes { + spv.target_env = #spv.target_env< + #spv.vce, {}> +} { + +// CHECK-LABEL: @load_i8 +func @load_i8(%arg0: memref) { + // CHECK: %[[ZERO:.+]] = spv.Constant 0 : i32 + // CHECK: %[[FOUR1:.+]] = spv.Constant 4 : i32 + // CHECK: %[[QUOTIENT:.+]] = spv.SDiv %[[ZERO]], %[[FOUR1]] : i32 + // CHECK: %[[PTR:.+]] = spv.AccessChain %{{.+}}[%[[ZERO]], %[[QUOTIENT]]] + // CHECK: %[[LOAD:.+]] = spv.Load "StorageBuffer" %[[PTR]] + // CHECK: %[[FOUR2:.+]] = spv.Constant 4 : i32 + // CHECK: %[[EIGHT:.+]] = spv.Constant 8 : i32 + // CHECK: %[[IDX:.+]] = spv.UMod %[[ZERO]], %[[FOUR2]] : i32 + // CHECK: %[[BITS:.+]] = spv.IMul %[[IDX]], %[[EIGHT]] : i32 + // CHECK: %[[VALUE:.+]] = spv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i32 + // CHECK: %[[MASK:.+]] = spv.Constant 255 : i32 + // CHECK: %[[T1:.+]] = spv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32 + // CHECK: %[[T2:.+]] = spv.Constant 24 : i32 + // CHECK: %[[T3:.+]] = spv.ShiftLeftLogical %[[T1]], %[[T2]] : i32, i32 + // CHECK: spv.ShiftRightArithmetic %[[T3]], %[[T2]] : i32, i32 + %0 = memref.load %arg0[] : memref + return +} + +// CHECK-LABEL: @load_i16 +func @load_i16(%arg0: memref) { + // CHECK-NOT: spv.SDiv + // CHECK: spv.Load + // CHECK-NOT: spv.ShiftRightArithmetic + %0 = memref.load %arg0[] : memref + return +} + +// CHECK-LABEL: @store_i8 +// CHECK: (%[[ARG0:.+]]: {{.*}}, %[[ARG1:.+]]: i8) +func @store_i8(%arg0: memref, %value: i8) { + // CHECK: %[[ARG1_CAST:.+]] = unrealized_conversion_cast %[[ARG1]] : i8 to i32 + // CHECK: %[[ARG0_CAST:.+]] = unrealized_conversion_cast %[[ARG0]] + // CHECK: %[[ZERO:.+]] = spv.Constant 0 : i32 + // CHECK: %[[FOUR:.+]] = spv.Constant 4 : i32 + // CHECK: %[[EIGHT:.+]] = spv.Constant 8 : i32 + // 
CHECK: %[[IDX:.+]] = spv.UMod %[[ZERO]], %[[FOUR]] : i32 + // CHECK: %[[OFFSET:.+]] = spv.IMul %[[IDX]], %[[EIGHT]] : i32 + // CHECK: %[[MASK1:.+]] = spv.Constant 255 : i32 + // CHECK: %[[TMP1:.+]] = spv.ShiftLeftLogical %[[MASK1]], %[[OFFSET]] : i32, i32 + // CHECK: %[[MASK:.+]] = spv.Not %[[TMP1]] : i32 + // CHECK: %[[CLAMPED_VAL:.+]] = spv.BitwiseAnd %[[ARG1_CAST]], %[[MASK1]] : i32 + // CHECK: %[[STORE_VAL:.+]] = spv.ShiftLeftLogical %[[CLAMPED_VAL]], %[[OFFSET]] : i32, i32 + // CHECK: %[[FOUR2:.+]] = spv.Constant 4 : i32 + // CHECK: %[[ACCESS_IDX:.+]] = spv.SDiv %[[ZERO]], %[[FOUR2]] : i32 + // CHECK: %[[PTR:.+]] = spv.AccessChain %[[ARG0_CAST]][%[[ZERO]], %[[ACCESS_IDX]]] + // CHECK: spv.AtomicAnd "Device" "AcquireRelease" %[[PTR]], %[[MASK]] + // CHECK: spv.AtomicOr "Device" "AcquireRelease" %[[PTR]], %[[STORE_VAL]] + memref.store %value, %arg0[] : memref + return +} + +// CHECK-LABEL: @store_i16 +func @store_i16(%arg0: memref<10xi16>, %index: index, %value: i16) { + // CHECK: spv.Store + // CHECK-NOT: spv.AtomicAnd + // CHECK-NOT: spv.AtomicOr + memref.store %value, %arg0[%index] : memref<10xi16> + return +} + +} // end module diff --git a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir --- a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir +++ b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir @@ -853,330 +853,6 @@ return } -//===----------------------------------------------------------------------===// -// memref load/store ops -//===----------------------------------------------------------------------===// - -// CHECK-LABEL: @load_store_zero_rank_float -// CHECK: [[ARG0:%.*]]: !spv.ptr [0])>, StorageBuffer>, -// CHECK: [[ARG1:%.*]]: !spv.ptr [0])>, StorageBuffer>) -func @load_store_zero_rank_float(%arg0: memref, %arg1: memref) { - // CHECK: [[ZERO1:%.*]] = spv.Constant 0 : i32 - // CHECK: spv.AccessChain [[ARG0]][ - // CHECK-SAME: [[ZERO1]], [[ZERO1]] - // 
CHECK-SAME: ] : - // CHECK: spv.Load "StorageBuffer" %{{.*}} : f32 - %0 = memref.load %arg0[] : memref - // CHECK: [[ZERO2:%.*]] = spv.Constant 0 : i32 - // CHECK: spv.AccessChain [[ARG1]][ - // CHECK-SAME: [[ZERO2]], [[ZERO2]] - // CHECK-SAME: ] : - // CHECK: spv.Store "StorageBuffer" %{{.*}} : f32 - memref.store %0, %arg1[] : memref - return -} - -// CHECK-LABEL: @load_store_zero_rank_int -// CHECK: [[ARG0:%.*]]: !spv.ptr [0])>, StorageBuffer>, -// CHECK: [[ARG1:%.*]]: !spv.ptr [0])>, StorageBuffer>) -func @load_store_zero_rank_int(%arg0: memref, %arg1: memref) { - // CHECK: [[ZERO1:%.*]] = spv.Constant 0 : i32 - // CHECK: spv.AccessChain [[ARG0]][ - // CHECK-SAME: [[ZERO1]], [[ZERO1]] - // CHECK-SAME: ] : - // CHECK: spv.Load "StorageBuffer" %{{.*}} : i32 - %0 = memref.load %arg0[] : memref - // CHECK: [[ZERO2:%.*]] = spv.Constant 0 : i32 - // CHECK: spv.AccessChain [[ARG1]][ - // CHECK-SAME: [[ZERO2]], [[ZERO2]] - // CHECK-SAME: ] : - // CHECK: spv.Store "StorageBuffer" %{{.*}} : i32 - memref.store %0, %arg1[] : memref - return -} - -// CHECK-LABEL: func @load_store_unknown_dim -// CHECK-SAME: %[[SRC:[a-z0-9]+]]: !spv.ptr [0])>, StorageBuffer>, -// CHECK-SAME: %[[DST:[a-z0-9]+]]: !spv.ptr [0])>, StorageBuffer>) -func @load_store_unknown_dim(%i: index, %source: memref, %dest: memref) { - // CHECK: %[[AC0:.+]] = spv.AccessChain %[[SRC]] - // CHECK: spv.Load "StorageBuffer" %[[AC0]] - %0 = memref.load %source[%i] : memref - // CHECK: %[[AC1:.+]] = spv.AccessChain %[[DST]] - // CHECK: spv.Store "StorageBuffer" %[[AC1]] - memref.store %0, %dest[%i]: memref - return -} - -} // end module - -// ----- - -// Check that access chain indices are properly adjusted if non-32-bit types are -// emulated via 32-bit types. -// TODO: Test i64 types. 
-module attributes { - spv.target_env = #spv.target_env< - #spv.vce, {}> -} { - -// CHECK-LABEL: @load_i1 -func @load_i1(%arg0: memref) -> i1 { - // CHECK: %[[ZERO:.+]] = spv.Constant 0 : i32 - // CHECK: %[[FOUR1:.+]] = spv.Constant 4 : i32 - // CHECK: %[[QUOTIENT:.+]] = spv.SDiv %[[ZERO]], %[[FOUR1]] : i32 - // CHECK: %[[PTR:.+]] = spv.AccessChain %{{.+}}[%[[ZERO]], %[[QUOTIENT]]] - // CHECK: %[[LOAD:.+]] = spv.Load "StorageBuffer" %[[PTR]] - // CHECK: %[[FOUR2:.+]] = spv.Constant 4 : i32 - // CHECK: %[[EIGHT:.+]] = spv.Constant 8 : i32 - // CHECK: %[[IDX:.+]] = spv.UMod %[[ZERO]], %[[FOUR2]] : i32 - // CHECK: %[[BITS:.+]] = spv.IMul %[[IDX]], %[[EIGHT]] : i32 - // CHECK: %[[VALUE:.+]] = spv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i32 - // CHECK: %[[MASK:.+]] = spv.Constant 255 : i32 - // CHECK: %[[T1:.+]] = spv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32 - // CHECK: %[[T2:.+]] = spv.Constant 24 : i32 - // CHECK: %[[T3:.+]] = spv.ShiftLeftLogical %[[T1]], %[[T2]] : i32, i32 - // CHECK: %[[T4:.+]] = spv.ShiftRightArithmetic %[[T3]], %[[T2]] : i32, i32 - // Convert to i1 type. 
- // CHECK: %[[ONE:.+]] = spv.Constant 1 : i32 - // CHECK: %[[ISONE:.+]] = spv.IEqual %[[T4]], %[[ONE]] : i32 - // CHECK: %[[FALSE:.+]] = spv.Constant false - // CHECK: %[[TRUE:.+]] = spv.Constant true - // CHECK: %[[RES:.+]] = spv.Select %[[ISONE]], %[[TRUE]], %[[FALSE]] : i1, i1 - // CHECK: spv.ReturnValue %[[RES]] : i1 - %0 = memref.load %arg0[] : memref - return %0 : i1 -} - -// CHECK-LABEL: @load_i8 -func @load_i8(%arg0: memref) { - // CHECK: %[[ZERO:.+]] = spv.Constant 0 : i32 - // CHECK: %[[FOUR1:.+]] = spv.Constant 4 : i32 - // CHECK: %[[QUOTIENT:.+]] = spv.SDiv %[[ZERO]], %[[FOUR1]] : i32 - // CHECK: %[[PTR:.+]] = spv.AccessChain %{{.+}}[%[[ZERO]], %[[QUOTIENT]]] - // CHECK: %[[LOAD:.+]] = spv.Load "StorageBuffer" %[[PTR]] - // CHECK: %[[FOUR2:.+]] = spv.Constant 4 : i32 - // CHECK: %[[EIGHT:.+]] = spv.Constant 8 : i32 - // CHECK: %[[IDX:.+]] = spv.UMod %[[ZERO]], %[[FOUR2]] : i32 - // CHECK: %[[BITS:.+]] = spv.IMul %[[IDX]], %[[EIGHT]] : i32 - // CHECK: %[[VALUE:.+]] = spv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i32 - // CHECK: %[[MASK:.+]] = spv.Constant 255 : i32 - // CHECK: %[[T1:.+]] = spv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32 - // CHECK: %[[T2:.+]] = spv.Constant 24 : i32 - // CHECK: %[[T3:.+]] = spv.ShiftLeftLogical %[[T1]], %[[T2]] : i32, i32 - // CHECK: spv.ShiftRightArithmetic %[[T3]], %[[T2]] : i32, i32 - %0 = memref.load %arg0[] : memref - return -} - -// CHECK-LABEL: @load_i16 -// CHECK: (%[[ARG0:.+]]: {{.*}}, %[[ARG1:.+]]: i32) -func @load_i16(%arg0: memref<10xi16>, %index : index) { - // CHECK: %[[ZERO:.+]] = spv.Constant 0 : i32 - // CHECK: %[[OFFSET:.+]] = spv.Constant 0 : i32 - // CHECK: %[[ONE:.+]] = spv.Constant 1 : i32 - // CHECK: %[[UPDATE:.+]] = spv.IMul %[[ONE]], %[[ARG1]] : i32 - // CHECK: %[[FLAT_IDX:.+]] = spv.IAdd %[[OFFSET]], %[[UPDATE]] : i32 - // CHECK: %[[TWO1:.+]] = spv.Constant 2 : i32 - // CHECK: %[[QUOTIENT:.+]] = spv.SDiv %[[FLAT_IDX]], %[[TWO1]] : i32 - // CHECK: %[[PTR:.+]] = spv.AccessChain 
%{{.+}}[%[[ZERO]], %[[QUOTIENT]]] - // CHECK: %[[LOAD:.+]] = spv.Load "StorageBuffer" %[[PTR]] - // CHECK: %[[TWO2:.+]] = spv.Constant 2 : i32 - // CHECK: %[[SIXTEEN:.+]] = spv.Constant 16 : i32 - // CHECK: %[[IDX:.+]] = spv.UMod %[[FLAT_IDX]], %[[TWO2]] : i32 - // CHECK: %[[BITS:.+]] = spv.IMul %[[IDX]], %[[SIXTEEN]] : i32 - // CHECK: %[[VALUE:.+]] = spv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i32 - // CHECK: %[[MASK:.+]] = spv.Constant 65535 : i32 - // CHECK: %[[T1:.+]] = spv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32 - // CHECK: %[[T2:.+]] = spv.Constant 16 : i32 - // CHECK: %[[T3:.+]] = spv.ShiftLeftLogical %[[T1]], %[[T2]] : i32, i32 - // CHECK: spv.ShiftRightArithmetic %[[T3]], %[[T2]] : i32, i32 - %0 = memref.load %arg0[%index] : memref<10xi16> - return -} - -// CHECK-LABEL: @load_i32 -func @load_i32(%arg0: memref) { - // CHECK-NOT: spv.SDiv - // CHECK: spv.Load - // CHECK-NOT: spv.ShiftRightArithmetic - %0 = memref.load %arg0[] : memref - return -} - -// CHECK-LABEL: @load_f32 -func @load_f32(%arg0: memref) { - // CHECK-NOT: spv.SDiv - // CHECK: spv.Load - // CHECK-NOT: spv.ShiftRightArithmetic - %0 = memref.load %arg0[] : memref - return -} - -// CHECK-LABEL: @store_i1 -// CHECK: (%[[ARG0:.+]]: {{.*}}, %[[ARG1:.+]]: i1) -func @store_i1(%arg0: memref, %value: i1) { - // CHECK: %[[ZERO:.+]] = spv.Constant 0 : i32 - // CHECK: %[[FOUR:.+]] = spv.Constant 4 : i32 - // CHECK: %[[EIGHT:.+]] = spv.Constant 8 : i32 - // CHECK: %[[IDX:.+]] = spv.UMod %[[ZERO]], %[[FOUR]] : i32 - // CHECK: %[[OFFSET:.+]] = spv.IMul %[[IDX]], %[[EIGHT]] : i32 - // CHECK: %[[MASK1:.+]] = spv.Constant 255 : i32 - // CHECK: %[[TMP1:.+]] = spv.ShiftLeftLogical %[[MASK1]], %[[OFFSET]] : i32, i32 - // CHECK: %[[MASK:.+]] = spv.Not %[[TMP1]] : i32 - // CHECK: %[[ZERO1:.+]] = spv.Constant 0 : i32 - // CHECK: %[[ONE1:.+]] = spv.Constant 1 : i32 - // CHECK: %[[CASTED_ARG1:.+]] = spv.Select %[[ARG1]], %[[ONE1]], %[[ZERO1]] : i1, i32 - // CHECK: %[[CLAMPED_VAL:.+]] = spv.BitwiseAnd 
%[[CASTED_ARG1]], %[[MASK1]] : i32 - // CHECK: %[[STORE_VAL:.+]] = spv.ShiftLeftLogical %[[CLAMPED_VAL]], %[[OFFSET]] : i32, i32 - // CHECK: %[[FOUR2:.+]] = spv.Constant 4 : i32 - // CHECK: %[[ACCESS_IDX:.+]] = spv.SDiv %[[ZERO]], %[[FOUR2]] : i32 - // CHECK: %[[PTR:.+]] = spv.AccessChain %[[ARG0]][%[[ZERO]], %[[ACCESS_IDX]]] - // CHECK: spv.AtomicAnd "Device" "AcquireRelease" %[[PTR]], %[[MASK]] - // CHECK: spv.AtomicOr "Device" "AcquireRelease" %[[PTR]], %[[STORE_VAL]] - memref.store %value, %arg0[] : memref - return -} - -// CHECK-LABEL: @store_i8 -// CHECK: (%[[ARG0:.+]]: {{.*}}, %[[ARG1:.+]]: i32) -func @store_i8(%arg0: memref, %value: i8) { - // CHECK: %[[ZERO:.+]] = spv.Constant 0 : i32 - // CHECK: %[[FOUR:.+]] = spv.Constant 4 : i32 - // CHECK: %[[EIGHT:.+]] = spv.Constant 8 : i32 - // CHECK: %[[IDX:.+]] = spv.UMod %[[ZERO]], %[[FOUR]] : i32 - // CHECK: %[[OFFSET:.+]] = spv.IMul %[[IDX]], %[[EIGHT]] : i32 - // CHECK: %[[MASK1:.+]] = spv.Constant 255 : i32 - // CHECK: %[[TMP1:.+]] = spv.ShiftLeftLogical %[[MASK1]], %[[OFFSET]] : i32, i32 - // CHECK: %[[MASK:.+]] = spv.Not %[[TMP1]] : i32 - // CHECK: %[[CLAMPED_VAL:.+]] = spv.BitwiseAnd %[[ARG1]], %[[MASK1]] : i32 - // CHECK: %[[STORE_VAL:.+]] = spv.ShiftLeftLogical %[[CLAMPED_VAL]], %[[OFFSET]] : i32, i32 - // CHECK: %[[FOUR2:.+]] = spv.Constant 4 : i32 - // CHECK: %[[ACCESS_IDX:.+]] = spv.SDiv %[[ZERO]], %[[FOUR2]] : i32 - // CHECK: %[[PTR:.+]] = spv.AccessChain %[[ARG0]][%[[ZERO]], %[[ACCESS_IDX]]] - // CHECK: spv.AtomicAnd "Device" "AcquireRelease" %[[PTR]], %[[MASK]] - // CHECK: spv.AtomicOr "Device" "AcquireRelease" %[[PTR]], %[[STORE_VAL]] - memref.store %value, %arg0[] : memref - return -} - -// CHECK-LABEL: @store_i16 -// CHECK: (%[[ARG0:.+]]: {{.*}}, %[[ARG1:.+]]: i32, %[[ARG2:.+]]: i32) -func @store_i16(%arg0: memref<10xi16>, %index: index, %value: i16) { - // CHECK: %[[ZERO:.+]] = spv.Constant 0 : i32 - // CHECK: %[[OFFSET:.+]] = spv.Constant 0 : i32 - // CHECK: %[[ONE:.+]] = spv.Constant 1 : i32 
- // CHECK: %[[UPDATE:.+]] = spv.IMul %[[ONE]], %[[ARG1]] : i32 - // CHECK: %[[FLAT_IDX:.+]] = spv.IAdd %[[OFFSET]], %[[UPDATE]] : i32 - // CHECK: %[[TWO:.+]] = spv.Constant 2 : i32 - // CHECK: %[[SIXTEEN:.+]] = spv.Constant 16 : i32 - // CHECK: %[[IDX:.+]] = spv.UMod %[[FLAT_IDX]], %[[TWO]] : i32 - // CHECK: %[[OFFSET:.+]] = spv.IMul %[[IDX]], %[[SIXTEEN]] : i32 - // CHECK: %[[MASK1:.+]] = spv.Constant 65535 : i32 - // CHECK: %[[TMP1:.+]] = spv.ShiftLeftLogical %[[MASK1]], %[[OFFSET]] : i32, i32 - // CHECK: %[[MASK:.+]] = spv.Not %[[TMP1]] : i32 - // CHECK: %[[CLAMPED_VAL:.+]] = spv.BitwiseAnd %[[ARG2]], %[[MASK1]] : i32 - // CHECK: %[[STORE_VAL:.+]] = spv.ShiftLeftLogical %[[CLAMPED_VAL]], %[[OFFSET]] : i32, i32 - // CHECK: %[[TWO2:.+]] = spv.Constant 2 : i32 - // CHECK: %[[ACCESS_IDX:.+]] = spv.SDiv %[[FLAT_IDX]], %[[TWO2]] : i32 - // CHECK: %[[PTR:.+]] = spv.AccessChain %[[ARG0]][%[[ZERO]], %[[ACCESS_IDX]]] - // CHECK: spv.AtomicAnd "Device" "AcquireRelease" %[[PTR]], %[[MASK]] - // CHECK: spv.AtomicOr "Device" "AcquireRelease" %[[PTR]], %[[STORE_VAL]] - memref.store %value, %arg0[%index] : memref<10xi16> - return -} - -// CHECK-LABEL: @store_i32 -func @store_i32(%arg0: memref, %value: i32) { - // CHECK: spv.Store - // CHECK-NOT: spv.AtomicAnd - // CHECK-NOT: spv.AtomicOr - memref.store %value, %arg0[] : memref - return -} - -// CHECK-LABEL: @store_f32 -func @store_f32(%arg0: memref, %value: f32) { - // CHECK: spv.Store - // CHECK-NOT: spv.AtomicAnd - // CHECK-NOT: spv.AtomicOr - memref.store %value, %arg0[] : memref - return -} - -} // end module - -// ----- - -// Check that access chain indices are properly adjusted if non-16/32-bit types -// are emulated via 32-bit types. 
-module attributes { - spv.target_env = #spv.target_env< - #spv.vce, {}> -} { - -// CHECK-LABEL: @load_i8 -func @load_i8(%arg0: memref) { - // CHECK: %[[ZERO:.+]] = spv.Constant 0 : i32 - // CHECK: %[[FOUR1:.+]] = spv.Constant 4 : i32 - // CHECK: %[[QUOTIENT:.+]] = spv.SDiv %[[ZERO]], %[[FOUR1]] : i32 - // CHECK: %[[PTR:.+]] = spv.AccessChain %{{.+}}[%[[ZERO]], %[[QUOTIENT]]] - // CHECK: %[[LOAD:.+]] = spv.Load "StorageBuffer" %[[PTR]] - // CHECK: %[[FOUR2:.+]] = spv.Constant 4 : i32 - // CHECK: %[[EIGHT:.+]] = spv.Constant 8 : i32 - // CHECK: %[[IDX:.+]] = spv.UMod %[[ZERO]], %[[FOUR2]] : i32 - // CHECK: %[[BITS:.+]] = spv.IMul %[[IDX]], %[[EIGHT]] : i32 - // CHECK: %[[VALUE:.+]] = spv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i32 - // CHECK: %[[MASK:.+]] = spv.Constant 255 : i32 - // CHECK: %[[T1:.+]] = spv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32 - // CHECK: %[[T2:.+]] = spv.Constant 24 : i32 - // CHECK: %[[T3:.+]] = spv.ShiftLeftLogical %[[T1]], %[[T2]] : i32, i32 - // CHECK: spv.ShiftRightArithmetic %[[T3]], %[[T2]] : i32, i32 - %0 = memref.load %arg0[] : memref - return -} - -// CHECK-LABEL: @load_i16 -func @load_i16(%arg0: memref) { - // CHECK-NOT: spv.SDiv - // CHECK: spv.Load - // CHECK-NOT: spv.ShiftRightArithmetic - %0 = memref.load %arg0[] : memref - return -} - -// CHECK-LABEL: @store_i8 -// CHECK: (%[[ARG0:.+]]: {{.*}}, %[[ARG1:.+]]: i32) -func @store_i8(%arg0: memref, %value: i8) { - // CHECK: %[[ZERO:.+]] = spv.Constant 0 : i32 - // CHECK: %[[FOUR:.+]] = spv.Constant 4 : i32 - // CHECK: %[[EIGHT:.+]] = spv.Constant 8 : i32 - // CHECK: %[[IDX:.+]] = spv.UMod %[[ZERO]], %[[FOUR]] : i32 - // CHECK: %[[OFFSET:.+]] = spv.IMul %[[IDX]], %[[EIGHT]] : i32 - // CHECK: %[[MASK1:.+]] = spv.Constant 255 : i32 - // CHECK: %[[TMP1:.+]] = spv.ShiftLeftLogical %[[MASK1]], %[[OFFSET]] : i32, i32 - // CHECK: %[[MASK:.+]] = spv.Not %[[TMP1]] : i32 - // CHECK: %[[CLAMPED_VAL:.+]] = spv.BitwiseAnd %[[ARG1]], %[[MASK1]] : i32 - // CHECK: %[[STORE_VAL:.+]] = 
spv.ShiftLeftLogical %[[CLAMPED_VAL]], %[[OFFSET]] : i32, i32 - // CHECK: %[[FOUR2:.+]] = spv.Constant 4 : i32 - // CHECK: %[[ACCESS_IDX:.+]] = spv.SDiv %[[ZERO]], %[[FOUR2]] : i32 - // CHECK: %[[PTR:.+]] = spv.AccessChain %[[ARG0]][%[[ZERO]], %[[ACCESS_IDX]]] - // CHECK: spv.AtomicAnd "Device" "AcquireRelease" %[[PTR]], %[[MASK]] - // CHECK: spv.AtomicOr "Device" "AcquireRelease" %[[PTR]], %[[STORE_VAL]] - memref.store %value, %arg0[] : memref - return -} - -// CHECK-LABEL: @store_i16 -func @store_i16(%arg0: memref<10xi16>, %index: index, %value: i16) { - // CHECK: spv.Store - // CHECK-NOT: spv.AtomicAnd - // CHECK-NOT: spv.AtomicOr - memref.store %value, %arg0[%index] : memref<10xi16> - return -} - } // end module // ----- diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -1918,6 +1918,7 @@ ":MathToLibm", ":MathToSPIRV", ":MemRefToLLVM", + ":MemRefToSPIRV", ":OpenACCToLLVM", ":OpenACCToSCF", ":OpenMPToLLVM", @@ -2991,6 +2992,7 @@ ":ConversionPassIncGen", ":GPUDialect", ":IR", + ":MemRefToSPIRV", ":Pass", ":SCFDialect", ":SCFToSPIRV", @@ -3686,7 +3688,6 @@ deps = [ ":ConversionPassIncGen", ":IR", - ":MemRefDialect", ":Pass", ":SPIRVConversion", ":SPIRVDialect", @@ -4232,6 +4233,7 @@ ":Affine", ":ConversionPassIncGen", ":IR", + ":MemRefToSPIRV", ":Pass", ":SCFDialect", ":SPIRVConversion", @@ -4358,6 +4360,32 @@ ], ) +cc_library( + name = "MemRefToSPIRV", + srcs = glob([ + "lib/Conversion/MemRefToSPIRV/*.cpp", + "lib/Conversion/MemRefToSPIRV/*.h", + ]) + ["lib/Conversion/PassDetail.h"], + hdrs = glob([ + "include/mlir/Conversion/MemRefToSPIRV/*.h", + ]), + includes = [ + "include", + "lib/Conversion/MemRefToSPIRV", + ], + deps = [ + ":ConversionPassIncGen", + ":IR", + ":MemRefDialect", + ":Pass", + ":SPIRVConversion", + ":SPIRVDialect", + ":Support", + ":Transforms", + "//llvm:Support", + 
], +) + cc_library( name = "MathToLLVM", srcs = glob(["lib/Conversion/MathToLLVM/*.cpp"]) + ["lib/Conversion/PassDetail.h"], @@ -4982,6 +5010,7 @@ ":MathTransforms", ":MemRefDialect", ":MemRefToLLVM", + ":MemRefToSPIRV", ":MemRefTransforms", ":NVVMDialect", ":OpenACCDialect",