diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h b/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h
@@ -41,6 +41,12 @@
 public:
   explicit SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr);
 
+  /// Gets the number of bytes used for a type when converted to SPIR-V
+  /// type. Note that it doesn't account for whether the type is legal for a
+  /// SPIR-V target (described by spirv::TargetEnvAttr). Returns null on
+  /// failure.
+  static Optional<int64_t> getConvertedTypeNumBytes(Type t);
+
   /// Gets the SPIR-V correspondence for the standard index type.
   static Type getIndexType(MLIRContext *context);
 
diff --git a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
--- a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
+++ b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
@@ -147,14 +147,24 @@
 }
 
 /// Returns the shifted `targetBits`-bit value with the given offset.
-Value shiftValue(Location loc, Value value, Value offset, Value mask,
-                 int targetBits, OpBuilder &builder) {
+static Value shiftValue(Location loc, Value value, Value offset, Value mask,
+                        int targetBits, OpBuilder &builder) {
   Type targetType = builder.getIntegerType(targetBits);
   Value result = builder.create<spirv::BitwiseAndOp>(loc, value, mask);
   return builder.create<spirv::ShiftLeftLogicalOp>(loc, targetType, result, offset);
 }
 
+/// Returns true if the allocations of type `t` can be lowered to SPIR-V.
+static bool isAllocationSupported(MemRefType t) {
+  // Currently only supports workgroup local memory allocations with static
+  // shape and int or float element type.
+  return t.hasStaticShape() &&
+         SPIRVTypeConverter::getMemorySpaceForStorageClass(
+             spirv::StorageClass::Workgroup) == t.getMemorySpace() &&
+         t.getElementType().isIntOrFloat();
+}
+
 //===----------------------------------------------------------------------===//
 // Operation conversion
 //===----------------------------------------------------------------------===//
@@ -165,6 +175,95 @@
 namespace {
 
+/// Converts an allocation operation to SPIR-V. Currently only supports lowering
+/// to Workgroup memory when the size is constant.
+class AllocOpPattern final : public SPIRVOpLowering<AllocOp> {
+public:
+  using SPIRVOpLowering<AllocOp>::SPIRVOpLowering;
+
+  LogicalResult
+  matchAndRewrite(AllocOp operation, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    MemRefType allocType = operation.getType();
+    if (!isAllocationSupported(allocType))
+      return operation.emitError("unhandled allocation type");
+
+    // Convert the alloc to a 1-D array of the same number of bytes. The element
+    // type might get converted to a different bitwidth type based on the target
+    // environment. Adjust the number of elements accordingly.
+    Optional<int64_t> typeSize =
+        SPIRVTypeConverter::getConvertedTypeNumBytes(allocType);
+    if (!typeSize)
+      return operation.emitError("unable to get converted type size");
+
+    Type elementType = allocType.getElementType();
+    Optional<int64_t> elementTypeBytes =
+        SPIRVTypeConverter::getConvertedTypeNumBytes(elementType);
+    int64_t numElements = typeSize.getValue() / *elementTypeBytes;
+
+    Type convertedElementType = typeConverter.convertType(elementType);
+    if (!convertedElementType) {
+      return operation.emitError(
+          "unable to convert element type of allocation");
+    }
+    // If the element type is not supported natively, further divide the number
+    // of elements by (convertedElementTypeSize / elementTypeSize). This will
+    // keep the number of bytes allocated consistent.
+    Optional<int64_t> convertedElementTypeBytes =
+        SPIRVTypeConverter::getConvertedTypeNumBytes(convertedElementType);
+    if (!convertedElementTypeBytes ||
+        *convertedElementTypeBytes % *elementTypeBytes)
+      return operation.emitError("unable to find converted allocation size");
+    numElements /= (*convertedElementTypeBytes / *elementTypeBytes);
+
+    // Get the SPIR-V type for the allocation.
+    Type spirvType = typeConverter.convertType(
+        MemRefType::get(numElements, convertedElementType,
+                        allocType.getAffineMaps(), allocType.getMemorySpace()));
+
+    // Insert spv.global_variable for this allocation.
+    Operation *parent =
+        SymbolTable::getNearestSymbolTable(operation.getParentOp());
+    if (!parent)
+      return failure();
+    Location loc = operation.getLoc();
+    spirv::GlobalVariableOp varOp;
+    {
+      OpBuilder::InsertionGuard guard(rewriter);
+      Block &entryBlock = *parent->getRegion(0).begin();
+      rewriter.setInsertionPointToStart(&entryBlock);
+      auto varOps = entryBlock.getOps<spirv::GlobalVariableOp>();
+      std::string varName =
+          std::string("__workgroup_mem__") +
+          std::to_string(std::distance(varOps.begin(), varOps.end()));
+      varOp = rewriter.create<spirv::GlobalVariableOp>(
+          loc, TypeAttr::get(spirvType), varName,
+          /*initializer = */ nullptr);
+    }
+
+    // Get pointer to global variable at the current scope.
+    rewriter.replaceOpWithNewOp<spirv::AddressOfOp>(operation, varOp);
+    return success();
+  }
+};
+
+/// Removes a deallocation if it is a supported allocation. Currently only
+/// removes deallocation if the memory space is workgroup memory.
+class DeallocOpPattern final : public SPIRVOpLowering<DeallocOp> {
+public:
+  using SPIRVOpLowering<DeallocOp>::SPIRVOpLowering;
+
+  LogicalResult
+  matchAndRewrite(DeallocOp operation, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    MemRefType deallocType = operation.memref().getType().cast<MemRefType>();
+    if (!isAllocationSupported(deallocType))
+      return operation.emitError("unhandled deallocation type");
+    rewriter.eraseOp(operation);
+    return success();
+  }
+};
+
 /// Converts unary and binary standard operations to SPIR-V operations.
 template <typename StdOp, typename SPIRVOp>
 class UnaryAndBinaryOpPattern final : public SPIRVOpLowering<StdOp> {
@@ -862,6 +961,7 @@
       UnaryAndBinaryOpPattern<ShiftLeftOp, spirv::ShiftLeftLogicalOp>,
       UnaryAndBinaryOpPattern<SignedShiftRightOp, spirv::ShiftRightArithmeticOp>,
       UnaryAndBinaryOpPattern<UnsignedShiftRightOp, spirv::ShiftRightLogicalOp>,
+      AllocOpPattern, DeallocOpPattern,
       BitwiseOpPattern<AndOp, spirv::LogicalAndOp, spirv::BitwiseAndOp>,
       BitwiseOpPattern<OrOp, spirv::LogicalOrOp, spirv::BitwiseOrOp>,
       BoolCmpIOpPattern, ConstantCompositeOpPattern, ConstantScalarOpPattern,
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
--- a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
@@ -218,6 +218,10 @@
   return llvm::None;
 }
 
+Optional<int64_t> SPIRVTypeConverter::getConvertedTypeNumBytes(Type t) {
+  return getTypeNumBytes(t);
+}
+
 /// Converts a scalar `type` to a suitable type under the given `targetEnv`.
 static Optional<Type> convertScalarType(const spirv::TargetEnv &targetEnv,
                                         spirv::ScalarType type,
@@ -574,35 +578,40 @@
     SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr,
     ArrayRef<Value> indices, Location loc, OpBuilder &builder) {
   // Get base and offset of the MemRefType and verify they are static.
+  int64_t offset;
   SmallVector<int64_t, 4> strides;
   if (failed(getStridesAndOffset(baseType, strides, offset)) ||
-      llvm::is_contained(strides, MemRefType::getDynamicStrideOrOffset())) {
+      llvm::is_contained(strides, MemRefType::getDynamicStrideOrOffset()) ||
+      offset == MemRefType::getDynamicStrideOrOffset()) {
     return nullptr;
   }
 
   auto indexType = typeConverter.getIndexType(builder.getContext());
-
-  Value ptrLoc = nullptr;
-  assert(indices.size() == strides.size() &&
-         "must provide indices for all dimensions");
-  for (auto index : enumerate(indices)) {
-    Value strideVal = builder.create<spirv::ConstantOp>(
-        loc, indexType, IntegerAttr::get(indexType, strides[index.index()]));
-    Value update = builder.create<spirv::IMulOp>(loc, strideVal, index.value());
-    ptrLoc =
-        (ptrLoc ? builder.create<spirv::IAddOp>(loc, ptrLoc, update).getResult()
-                : update);
-  }
   SmallVector<Value, 2> linearizedIndices;
   // Add a '0' at the start to index into the struct.
   auto zero = spirv::ConstantOp::getZero(indexType, loc, builder);
   linearizedIndices.push_back(zero);
-  // If it is a zero-rank memref type, extract the element directly.
-  if (!ptrLoc) {
-    ptrLoc = zero;
+
+  if (baseType.getRank() == 0) {
+    linearizedIndices.push_back(zero);
+  } else {
+    // TODO: Instead of this logic, use affine.apply and add patterns for
+    // lowering affine.apply to standard ops. These will get lowered to SPIR-V
+    // ops by the DialectConversion framework.
+    Value ptrLoc = builder.create<spirv::ConstantOp>(
+        loc, indexType, IntegerAttr::get(indexType, offset));
+    assert(indices.size() == strides.size() &&
+           "must provide indices for all dimensions");
+    for (auto index : enumerate(indices)) {
+      Value strideVal = builder.create<spirv::ConstantOp>(
+          loc, indexType, IntegerAttr::get(indexType, strides[index.index()]));
+      Value update =
+          builder.create<spirv::IMulOp>(loc, strideVal, index.value());
+      ptrLoc = builder.create<spirv::IAddOp>(loc, ptrLoc, update);
+    }
+    linearizedIndices.push_back(ptrLoc);
   }
-  linearizedIndices.push_back(ptrLoc);
   return builder.create<spirv::AccessChainOp>(loc, basePtr, linearizedIndices);
 }
diff --git a/mlir/test/Conversion/GPUToSPIRV/load-store.mlir b/mlir/test/Conversion/GPUToSPIRV/load-store.mlir
--- a/mlir/test/Conversion/GPUToSPIRV/load-store.mlir
+++ b/mlir/test/Conversion/GPUToSPIRV/load-store.mlir
@@ -58,13 +58,15 @@
     %12 = addi %arg3, %0 : index
     // CHECK: %[[INDEX2:.*]] = spv.IAdd %[[ARG4]], %[[LOCALINVOCATIONIDX]]
     %13 = addi %arg4, %3 : index
+    // CHECK: %[[ZERO:.*]] = spv.constant 0 : i32
+    // CHECK: %[[OFFSET1_0:.*]] = spv.constant 0 : i32
     // CHECK: %[[STRIDE1_1:.*]] = spv.constant 4 : i32
-    // CHECK: %[[OFFSET1_1:.*]] = spv.IMul %[[STRIDE1_1]], %[[INDEX1]] : i32
+    // CHECK: %[[UPDATE1_1:.*]] = spv.IMul %[[STRIDE1_1]], %[[INDEX1]] : i32
+    // CHECK: %[[OFFSET1_1:.*]] = spv.IAdd %[[OFFSET1_0]], %[[UPDATE1_1]] : i32
     // CHECK: %[[STRIDE1_2:.*]] = spv.constant 1 : i32
     // CHECK: %[[UPDATE1_2:.*]] = spv.IMul %[[STRIDE1_2]], %[[INDEX2]] : i32
     // CHECK: %[[OFFSET1_2:.*]] = spv.IAdd %[[OFFSET1_1]], %[[UPDATE1_2]] : i32
-    // CHECK: %[[ZERO1:.*]] = spv.constant 0 : i32
-    // CHECK: %[[PTR1:.*]] = spv.AccessChain %[[ARG0]]{{\[}}%[[ZERO1]], %[[OFFSET1_2]]{{\]}}
+    // CHECK: %[[PTR1:.*]] = spv.AccessChain %[[ARG0]]{{\[}}%[[ZERO]], %[[OFFSET1_2]]{{\]}}
     // CHECK-NEXT: %[[VAL1:.*]] = spv.Load "StorageBuffer" %[[PTR1]]
     %14 = load %arg0[%12, %13] : memref<12x4xf32>
     // CHECK: %[[PTR2:.*]] = spv.AccessChain %[[ARG1]]{{\[}}{{%.*}}, {{%.*}}{{\]}}
diff --git a/mlir/test/Conversion/GPUToSPIRV/loop.mlir b/mlir/test/Conversion/GPUToSPIRV/loop.mlir
--- a/mlir/test/Conversion/GPUToSPIRV/loop.mlir
+++ b/mlir/test/Conversion/GPUToSPIRV/loop.mlir
@@ -28,13 +28,17 @@
     // CHECK: %[[CMP:.*]] = spv.SLessThan %[[INDVAR]], %[[UB]] : i32
     // CHECK: spv.BranchConditional %[[CMP]], ^[[BODY:.*]], ^[[MERGE:.*]]
     // CHECK: ^[[BODY]]:
-    // CHECK: %[[STRIDE1:.*]] = spv.constant 1 : i32
-    // CHECK: %[[INDEX1:.*]] = spv.IMul %[[STRIDE1]], %[[INDVAR]] : i32
     // CHECK: %[[ZERO1:.*]] = spv.constant 0 : i32
+    // CHECK: %[[OFFSET1:.*]] = spv.constant 0 : i32
+    // CHECK: %[[STRIDE1:.*]] = spv.constant 1 : i32
+    // CHECK: %[[UPDATE1:.*]] = spv.IMul %[[STRIDE1]], %[[INDVAR]] : i32
+    // CHECK: %[[INDEX1:.*]] = spv.IAdd %[[OFFSET1]], %[[UPDATE1]] : i32
     // CHECK: spv.AccessChain {{%.*}}{{\[}}%[[ZERO1]], %[[INDEX1]]{{\]}}
-    // CHECK: %[[STRIDE2:.*]] = spv.constant 1 : i32
-    // CHECK: %[[INDEX2:.*]] = spv.IMul %[[STRIDE2]], %[[INDVAR]] : i32
     // CHECK: %[[ZERO2:.*]] = spv.constant 0 : i32
+    // CHECK: %[[OFFSET2:.*]] = spv.constant 0 : i32
+    // CHECK: %[[STRIDE2:.*]] = spv.constant 1 : i32
+    // CHECK: %[[UPDATE2:.*]] = spv.IMul %[[STRIDE2]], %[[INDVAR]] : i32
+    // CHECK: %[[INDEX2:.*]] = spv.IAdd %[[OFFSET2]], %[[UPDATE2]] : i32
     // CHECK: spv.AccessChain {{%.*}}[%[[ZERO2]], %[[INDEX2]]]
     // CHECK: %[[INCREMENT:.*]] = spv.IAdd %[[INDVAR]], %[[STEP]] : i32
     // CHECK: spv.Branch ^[[HEADER]](%[[INCREMENT]] : i32)
diff --git a/mlir/test/Conversion/StandardToSPIRV/alloc.mlir b/mlir/test/Conversion/StandardToSPIRV/alloc.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Conversion/StandardToSPIRV/alloc.mlir
@@ -0,0 +1,129 @@
+// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -convert-std-to-spirv -verify-diagnostics %s -o - | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// std allocation/deallocation ops
+//===----------------------------------------------------------------------===//
+
+module attributes {
+  spv.target_env = #spv.target_env<
+    #spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>,
+    {max_compute_workgroup_invocations = 128 : i32,
+     max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+  }
+{
+  func @alloc_dealloc_workgroup_mem(%arg0 : index, %arg1 : index) {
+    %0 = alloc() : memref<4x5xf32, 3>
+    %1 = load %0[%arg0, %arg1] : memref<4x5xf32, 3>
+    store %1, %0[%arg0, %arg1] : memref<4x5xf32, 3>
+    dealloc %0 : memref<4x5xf32, 3>
+    return
+  }
+}
+// CHECK: spv.globalVariable @[[VAR:.+]] : !spv.ptr<!spv.struct<!spv.array<20 x f32, stride=4> [0]>, Workgroup>
+// CHECK: func @alloc_dealloc_workgroup_mem
+// CHECK-NOT: alloc
+// CHECK: %[[PTR:.+]] = spv._address_of @[[VAR]]
+// CHECK: %[[LOADPTR:.+]] = spv.AccessChain %[[PTR]]
+// CHECK: %[[VAL:.+]] = spv.Load "Workgroup" %[[LOADPTR]] : f32
+// CHECK: %[[STOREPTR:.+]] = spv.AccessChain %[[PTR]]
+// CHECK: spv.Store "Workgroup" %[[STOREPTR]], %[[VAL]] : f32
+// CHECK-NOT: dealloc
+// CHECK: spv.Return
+
+// -----
+
+// TODO: Uncomment this test when the extension handling correctly
+// converts an i16 type to i32 type and handles the load/stores
+// correctly.
+
+module attributes {
+  spv.target_env = #spv.target_env<
+    #spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>,
+    {max_compute_workgroup_invocations = 128 : i32,
+     max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+  }
+{
+  func @alloc_dealloc_workgroup_mem(%arg0 : index, %arg1 : index) {
+    %0 = alloc() : memref<4x5xi16, 3>
+    %1 = load %0[%arg0, %arg1] : memref<4x5xi16, 3>
+    store %1, %0[%arg0, %arg1] : memref<4x5xi16, 3>
+    dealloc %0 : memref<4x5xi16, 3>
+    return
+  }
+}
+// CHECK: spv.globalVariable @[[VAR:.+]] : !spv.ptr<!spv.struct<!spv.array<10 x i32, stride=4> [0]>, Workgroup>
+// CHECK: func @alloc_dealloc_workgroup_mem
+// CHECK-NOT: alloc
+// CHECK: %[[PTR:.+]] = spv._address_of @[[VAR]]
+// CHECK: %[[LOADPTR:.+]] = spv.AccessChain %[[PTR]]
+// CHECK-NOT: dealloc
+// CHECK: spv.Return
+
+
+// -----
+
+module attributes {
+  spv.target_env = #spv.target_env<
+    #spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>,
+    {max_compute_workgroup_invocations = 128 : i32,
+     max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+  }
+{
+  func @alloc_dealloc_dynamic_workgroup_mem(%arg0 : index) {
+    // expected-error @+2 {{unhandled allocation type}}
+    // expected-error @+1 {{'std.alloc' op operand #0 must be index, but got 'i32'}}
+    %0 = alloc(%arg0) : memref<4x?xf32, 3>
+    return
+  }
+}
+
+// -----
+
+module attributes {
+  spv.target_env = #spv.target_env<
+    #spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>,
+    {max_compute_workgroup_invocations = 128 : i32,
+     max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+  }
+{
+  func @alloc_dealloc_mem() {
+    // expected-error @+1 {{unhandled allocation type}}
+    %0 = alloc() : memref<4x5xf32>
+    return
+  }
+}
+
+
+// -----
+
+module attributes {
+  spv.target_env = #spv.target_env<
+    #spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>,
+    {max_compute_workgroup_invocations = 128 : i32,
+     max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+  }
+{
+  func @alloc_dealloc_dynamic_workgroup_mem(%arg0 : memref<4x?xf32, 3>) {
+    // expected-error @+2 {{unhandled deallocation type}}
+    // expected-error @+1 {{'std.dealloc' op operand #0 must be memref of any type values, but got '!spv.ptr<!spv.struct<!spv.rtarray<f32, stride=4> [0]>, Workgroup>'}}
+    dealloc %arg0 : memref<4x?xf32, 3>
+    return
+  }
+}
+
+// -----
+
+module attributes {
+  spv.target_env = #spv.target_env<
+    #spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>,
+    {max_compute_workgroup_invocations = 128 : i32,
+     max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+  }
+{
+  func @alloc_dealloc_mem(%arg0 : memref<4x5xf32>) {
+    // expected-error @+2 {{unhandled deallocation type}}
+    // expected-error @+1 {{op operand #0 must be memref of any type values, but got '!spv.ptr<!spv.struct<!spv.array<20 x f32, stride=4> [0]>, StorageBuffer>'}}
+    dealloc %arg0 : memref<4x5xf32>
+    return
+  }
+}
diff --git a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
--- a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
+++ b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
@@ -725,9 +725,11 @@
 // CHECK-LABEL: @load_i16
 // CHECK: (%[[ARG0:.+]]: {{.*}}, %[[ARG1:.+]]: i32)
 func @load_i16(%arg0: memref<10xi16>, %index : index) {
-  // CHECK: %[[ONE:.+]] = spv.constant 1 : i32
-  // CHECK: %[[FLAT_IDX:.+]] = spv.IMul %[[ONE]], %[[ARG1]] : i32
   // CHECK: %[[ZERO:.+]] = spv.constant 0 : i32
+  // CHECK: %[[OFFSET:.+]] = spv.constant 0 : i32
+  // CHECK: %[[ONE:.+]] = spv.constant 1 : i32
+  // CHECK: %[[UPDATE:.+]] = spv.IMul %[[ONE]], %[[ARG1]] : i32
+  // CHECK: %[[FLAT_IDX:.+]] = spv.IAdd %[[OFFSET]], %[[UPDATE]] : i32
   // CHECK: %[[TWO1:.+]] = spv.constant 2 : i32
   // CHECK: %[[QUOTIENT:.+]] = spv.SDiv %[[FLAT_IDX]], %[[TWO1]] : i32
   // CHECK: %[[PTR:.+]] = spv.AccessChain %{{.+}}[%[[ZERO]], %[[QUOTIENT]]]
@@ -786,9 +788,11 @@
 // CHECK-LABEL: @store_i16
 // CHECK: (%[[ARG0:.+]]: {{.*}}, %[[ARG1:.+]]: i32, %[[ARG2:.+]]: i32)
 func @store_i16(%arg0: memref<10xi16>, %index: index, %value: i16) {
-  // CHECK: %[[ONE:.+]] = spv.constant 1 : i32
-  // CHECK: %[[FLAT_IDX:.+]] = spv.IMul %[[ONE]], %[[ARG1]] : i32
   // CHECK: %[[ZERO:.+]] = spv.constant 0 : i32
+  // CHECK: %[[OFFSET:.+]] = spv.constant 0 : i32
+  // CHECK: %[[ONE:.+]] = spv.constant 1 : i32
+  // CHECK: %[[UPDATE:.+]] = spv.IMul %[[ONE]], %[[ARG1]] : i32
+  // CHECK: %[[FLAT_IDX:.+]] = spv.IAdd %[[OFFSET]], %[[UPDATE]] : i32
   // CHECK: %[[TWO:.+]] = spv.constant 2 : i32
   // CHECK: %[[SIXTEEN:.+]] = spv.constant 16 : i32
   // CHECK: %[[IDX:.+]] = spv.SMod %[[FLAT_IDX]], %[[TWO]] : i32