diff --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h --- a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h +++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h @@ -53,16 +53,22 @@ /// default. unsigned boolNumBits; + /// Use 64-bit integer type for index calculations and get***id() builtins. + /// Defaults to 32-bit. Vulkan requires 32-bit get***id() builtins. + bool use64bitIndex; + // Note: we need this instead of inline initializers becuase of // https://bugs.llvm.org/show_bug.cgi?id=36684 - Options() : emulateNon32BitScalarTypes(true), boolNumBits(8) {} + Options() + : emulateNon32BitScalarTypes(true), boolNumBits(8), + use64bitIndex(false) {} }; explicit SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr, Options options = {}); /// Gets the SPIR-V correspondence for the standard index type. - static Type getIndexType(MLIRContext *context); + Type getIndexType() const; /// Returns the corresponding memory space for memref given a SPIR-V storage /// class. @@ -76,6 +82,8 @@ /// Returns the options controlling the SPIR-V type converter. const Options &getOptions() const; + MLIRContext *getContext() const; + private: spirv::TargetEnv targetEnv; Options options; @@ -130,7 +138,7 @@ /// inserts the global variable associated for the builtin within the nearest /// symbol table enclosing `op`. Returns null Value on error. Value getBuiltinVariableValue(Operation *op, BuiltIn builtin, - OpBuilder &builder); + OpBuilder &builder, Type indexType); /// Gets the value at the given `offset` of the push constant storage with a /// total of `elementCount` 32-bit integers. A global variable will be created @@ -138,12 +146,13 @@ /// not existing. Load ops will be created via the given `builder` to load /// values from the push constant. Returns null Value on error.
Value getPushConstantValue(Operation *op, unsigned elementCount, - unsigned offset, OpBuilder &builder); + unsigned offset, OpBuilder &builder, Type indexType); /// Generates IR to perform index linearization with the given `indices` and /// their corresponding `strides`, adding an initial `offset`. Value linearizeIndex(ValueRange indices, ArrayRef strides, - int64_t offset, Location loc, OpBuilder &builder); + int64_t offset, Location loc, OpBuilder &builder, + Type indexType); /// Performs the index computation to get to the element at `indices` of the /// memory pointed to by `basePtr`, using the layout map of `baseType`. @@ -154,7 +163,7 @@ spirv::AccessChainOp getElementPtr(SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, - OpBuilder &builder); + OpBuilder &builder, Type indexType); } // namespace spirv } // namespace mlir diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp --- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp +++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp @@ -135,10 +135,14 @@ if (!index) return failure(); + auto *typeConverter = getTypeConverter(); + auto indexType = typeConverter->getIndexType(); + - // SPIR-V invocation builtin variables are a vector of type <3xi32> - auto spirvBuiltin = spirv::getBuiltinVariableValue(op, builtin, rewriter); + // SPIR-V invocation builtin variables are a vector of 3 `indexType` values. + auto spirvBuiltin = + spirv::getBuiltinVariableValue(op, builtin, rewriter, indexType); rewriter.replaceOpWithNewOp( - op, rewriter.getIntegerType(32), spirvBuiltin, + op, indexType, spirvBuiltin, rewriter.getI32ArrayAttr({index.getValue()})); return success(); } @@ -148,7 +152,11 @@ SingleDimLaunchConfigConversion::matchAndRewrite( SourceOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const { - auto spirvBuiltin = spirv::getBuiltinVariableValue(op, builtin, rewriter); + auto *typeConverter = getTypeConverter(); + auto indexType = typeConverter->getIndexType(); + + auto
spirvBuiltin = + spirv::getBuiltinVariableValue(op, builtin, rewriter, indexType); rewriter.replaceOp(op, spirvBuiltin); return success(); } diff --git a/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp b/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp --- a/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp +++ b/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp @@ -27,10 +27,10 @@ /// location invocation ID. This function will create necessary operations with /// `builder` at the proper region containing `op`. static Value getLocalInvocationDimSize(Operation *op, int dim, Location loc, - OpBuilder *builder) { + OpBuilder *builder, Type indexType) { assert(dim >= 0 && dim < 3 && "local invocation only has three dimensions"); Value invocation = spirv::getBuiltinVariableValue( - op, spirv::BuiltIn::LocalInvocationId, *builder); + op, spirv::BuiltIn::LocalInvocationId, *builder, indexType); Type xType = invocation.getType().cast().getElementType(); return builder->create( loc, xType, invocation, builder->getI32ArrayAttr({dim})); @@ -137,16 +137,20 @@ Value convertedInput = operands[0], convertedOutput = operands[1]; Location loc = genericOp.getLoc(); + auto *typeConverter = getTypeConverter(); + auto indexType = typeConverter->getIndexType(); + // Get the invocation ID. - Value x = getLocalInvocationDimSize(genericOp, /*dim=*/0, loc, &rewriter); + Value x = getLocalInvocationDimSize(genericOp, /*dim=*/0, loc, &rewriter, + indexType); // TODO: Load to Workgroup storage class first. - auto *typeConverter = getTypeConverter(); // Get the input element accessed by this invocation. - Value inputElementPtr = spirv::getElementPtr( - *typeConverter, originalInputType, convertedInput, {x}, loc, rewriter); + Value inputElementPtr = + spirv::getElementPtr(*typeConverter, originalInputType, convertedInput, + {x}, loc, rewriter, indexType); Value inputElement = rewriter.create(loc, inputElementPtr); // Perform the group reduction operation.
@@ -164,12 +168,11 @@ #undef CREATE_GROUP_NON_UNIFORM_BIN_OP // Get the output element accessed by this reduction. - Value zero = spirv::ConstantOp::getZero( - typeConverter->getIndexType(rewriter.getContext()), loc, rewriter); + Value zero = spirv::ConstantOp::getZero(indexType, loc, rewriter); SmallVector zeroIndices(originalOutputType.getRank(), zero); Value outputElementPtr = spirv::getElementPtr(*typeConverter, originalOutputType, convertedOutput, - zeroIndices, loc, rewriter); + zeroIndices, loc, rewriter, indexType); // Write out the final reduction result. This should be only conducted by one // invocation. We use spv.GroupNonUniformElect to find the invocation with the diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp --- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp +++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp @@ -284,9 +284,11 @@ return failure(); auto &typeConverter = *getTypeConverter(); + auto indexType = typeConverter.getIndexType(); + spirv::AccessChainOp accessChainOp = spirv::getElementPtr(typeConverter, memrefType, loadOperands.memref(), - loadOperands.indices(), loc, rewriter); + loadOperands.indices(), loc, rewriter, indexType); if (!accessChainOp) return failure(); @@ -378,9 +380,14 @@ auto memrefType = loadOp.memref().getType().cast(); if (memrefType.getElementType().isSignlessInteger()) return failure(); - auto loadPtr = spirv::getElementPtr( - *getTypeConverter(), memrefType, - loadOperands.memref(), loadOperands.indices(), loadOp.getLoc(), rewriter); + + auto &typeConverter = *getTypeConverter(); + auto indexType = typeConverter.getIndexType(); + + auto loadPtr = + spirv::getElementPtr(typeConverter, memrefType, + loadOperands.memref(), loadOperands.indices(), + loadOp.getLoc(), rewriter, indexType); if (!loadPtr) return failure(); @@ -400,9 +407,11 @@ auto loc = storeOp.getLoc(); auto &typeConverter = *getTypeConverter(); + auto indexType =
typeConverter.getIndexType(); + spirv::AccessChainOp accessChainOp = spirv::getElementPtr(typeConverter, memrefType, storeOperands.memref(), - storeOperands.indices(), loc, rewriter); + storeOperands.indices(), loc, rewriter, indexType); if (!accessChainOp) return failure(); @@ -494,10 +503,14 @@ auto memrefType = storeOp.memref().getType().cast(); if (memrefType.getElementType().isSignlessInteger()) return failure(); + + auto &typeConverter = *getTypeConverter(); + auto indexType = typeConverter.getIndexType(); + - auto storePtr = spirv::getElementPtr(*getTypeConverter(), memrefType, + auto storePtr = spirv::getElementPtr(typeConverter, memrefType, storeOperands.memref(), storeOperands.indices(), - storeOp.getLoc(), rewriter); + storeOp.getLoc(), rewriter, indexType); if (!storePtr) return failure(); diff --git a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp --- a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp @@ -373,8 +373,11 @@ return failure(); } + auto &typeConverter = *getTypeConverter(); + auto indexType = typeConverter.getIndexType(); + Value index = spirv::linearizeIndex(adaptor.indices(), strides, - /*offset=*/0, loc, rewriter); + /*offset=*/0, loc, rewriter, indexType); auto acOp = rewriter.create(loc, varOp, index); rewriter.replaceOpWithNewOp(extractOp, acOp); diff --git a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp --- a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp @@ -178,6 +178,9 @@ TypeConverter::SignatureConversion signatureConverter( funcOp.getType().getNumInputs()); + auto &typeConverter = *getTypeConverter(); + auto indexType = typeConverter.getIndexType(); + auto attrName = spirv::getInterfaceVarABIAttrName(); for (auto argType : llvm::enumerate(funcOp.getType().getInputs())) { auto abiInfo =
funcOp.getArgAttrOfType( @@ -206,7 +209,6 @@ // before the use. There might be multiple loads and currently there is no // easy way to replace all uses with a sequence of operations. if (argType.value().cast().isScalarOrVector()) { - auto indexType = SPIRVTypeConverter::getIndexType(funcOp.getContext()); auto zero = spirv::ConstantOp::getZero(indexType, funcOp.getLoc(), rewriter); auto loadPtr = rewriter.create( diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp --- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp @@ -112,15 +112,10 @@ // Type Conversion //===----------------------------------------------------------------------===// -Type SPIRVTypeConverter::getIndexType(MLIRContext *context) { - // Convert to 32-bit integers for now. Might need a way to control this in - // future. - // TODO: It is probably better to make it 64-bit integers. To - // this some support is needed in SPIR-V dialect for Conversion - // instructions. The Vulkan spec requires the builtins like - // GlobalInvocationID, etc. to be 32-bit (unsigned) integers which should be - // SExtended to 64-bit for index computations. - return IntegerType::get(context, 32); +Type SPIRVTypeConverter::getIndexType() const { + // Vulkan requires builtins such as GlobalInvocationId to be 32-bit integers, + // so `use64bitIndex` should only be enabled when 64-bit builtins are valid. + return IntegerType::get(getContext(), options.use64bitIndex ? 64 : 32); } /// Mapping between SPIR-V storage classes to memref memory spaces. @@ -183,6 +178,10 @@ return options; } +MLIRContext *SPIRVTypeConverter::getContext() const { + return targetEnv.getAttr().getContext(); +} + #undef STORAGE_SPACE_MAP_LIST // TODO: This is a utility function that should probably be exposed by the @@ -505,9 +504,7 @@ // want to validate and convert to be safe.
addConversion([](spirv::SPIRVType type) { return type; }); - addConversion([](IndexType indexType) { - return SPIRVTypeConverter::getIndexType(indexType.getContext()); - }); + addConversion([this](IndexType /*indexType*/) { return getIndexType(); }); addConversion([this](IntegerType intType) -> Optional { if (auto scalarType = intType.dyn_cast()) @@ -630,7 +627,7 @@ /// Gets or inserts a global variable for a builtin within `body` block. static spirv::GlobalVariableOp getOrInsertBuiltinVariable(Block &body, Location loc, spirv::BuiltIn builtin, - OpBuilder &builder) { + OpBuilder &builder, Type indexType) { if (auto varOp = getBuiltinVariable(body, builtin)) return varOp; @@ -644,9 +641,8 @@ case spirv::BuiltIn::WorkgroupId: case spirv::BuiltIn::LocalInvocationId: case spirv::BuiltIn::GlobalInvocationId: { - auto ptrType = spirv::PointerType::get( - VectorType::get({3}, builder.getIntegerType(32)), - spirv::StorageClass::Input); + auto ptrType = spirv::PointerType::get(VectorType::get({3}, indexType), + spirv::StorageClass::Input); std::string name = getBuiltinVarName(builtin); newVarOp = builder.create(loc, ptrType, name, builtin); @@ -655,8 +651,8 @@ case spirv::BuiltIn::SubgroupId: case spirv::BuiltIn::NumSubgroups: case spirv::BuiltIn::SubgroupSize: { - auto ptrType = spirv::PointerType::get(builder.getIntegerType(32), - spirv::StorageClass::Input); + auto ptrType = + spirv::PointerType::get(indexType, spirv::StorageClass::Input); std::string name = getBuiltinVarName(builtin); newVarOp = builder.create(loc, ptrType, name, builtin); @@ -671,7 +667,7 @@ Value mlir::spirv::getBuiltinVariableValue(Operation *op, spirv::BuiltIn builtin, - OpBuilder &builder) { + OpBuilder &builder, Type indexType) { Operation *parent = SymbolTable::getNearestSymbolTable(op->getParentOp()); if (!parent) { op->emitError("expected operation to be within a module-like op"); @@ -679,7 +675,7 @@ } spirv::GlobalVariableOp varOp = getOrInsertBuiltinVariable( - *parent->getRegion(0).begin(),
op->getLoc(), builtin, builder); + *parent->getRegion(0).begin(), op->getLoc(), builtin, builder, indexType); Value ptr = builder.create(op->getLoc(), varOp); return builder.create(op->getLoc(), ptr); } @@ -691,10 +687,10 @@ /// Returns the pointer type for the push constant storage containing - /// `elementCount` 32-bit integer values. + /// `elementCount` integer values of type `indexType`. static spirv::PointerType getPushConstantStorageType(unsigned elementCount, - Builder &builder) { - auto arrayType = spirv::ArrayType::get( - SPIRVTypeConverter::getIndexType(builder.getContext()), elementCount, - /*stride=*/4); + Builder &builder, + Type indexType) { + unsigned byteStride = indexType.getIntOrFloatBitWidth() / 8; + auto arrayType = spirv::ArrayType::get(indexType, elementCount, byteStride); auto structType = spirv::StructType::get({arrayType}, /*offsetInfo=*/0); return spirv::PointerType::get(structType, spirv::StorageClass::PushConstant); } @@ -725,19 +721,21 @@ - /// `elementCount` 32-bit integer values in `block`. + /// `elementCount` integer values of type `indexType` in `block`. static spirv::GlobalVariableOp getOrInsertPushConstantVariable(Location loc, Block &block, - unsigned elementCount, OpBuilder &b) { + unsigned elementCount, OpBuilder &b, + Type indexType) { if (auto varOp = getPushConstantVariable(block, elementCount)) return varOp; auto builder = OpBuilder::atBlockBegin(&block, b.getListener()); - auto type = getPushConstantStorageType(elementCount, builder); + auto type = getPushConstantStorageType(elementCount, builder, indexType); const char *name = "__push_constant_var__"; return builder.create(loc, type, name, /*initializer=*/nullptr); } Value spirv::getPushConstantValue(Operation *op, unsigned elementCount, - unsigned offset, OpBuilder &builder) { + unsigned offset, OpBuilder &builder, + Type indexType) { Location loc = op->getLoc(); Operation *parent = SymbolTable::getNearestSymbolTable(op->getParentOp()); if (!parent) { @@ -746,12 +744,11 @@ } spirv::GlobalVariableOp varOp = getOrInsertPushConstantVariable( - loc, parent->getRegion(0).front(), elementCount, builder); + loc, parent->getRegion(0).front(), elementCount,
builder, indexType); - auto i32Type = SPIRVTypeConverter::getIndexType(builder.getContext()); - Value zeroOp = spirv::ConstantOp::getZero(i32Type, loc, builder); + Value zeroOp = spirv::ConstantOp::getZero(indexType, loc, builder); Value offsetOp = builder.create( - loc, i32Type, builder.getI32IntegerAttr(offset)); + loc, indexType, builder.getIntegerAttr(indexType, offset)); auto addrOp = builder.create(loc, varOp); auto acOp = builder.create( loc, addrOp, llvm::makeArrayRef({zeroOp, offsetOp})); @@ -764,12 +761,10 @@ Value mlir::spirv::linearizeIndex(ValueRange indices, ArrayRef strides, int64_t offset, Location loc, - OpBuilder &builder) { + OpBuilder &builder, Type indexType) { assert(indices.size() == strides.size() && "must provide indices for all dimensions"); - auto indexType = SPIRVTypeConverter::getIndexType(builder.getContext()); - // TODO: Consider moving to use affine.apply and patterns converting // affine.apply to standard ops. This needs converting to SPIR-V passes to be // broken down into progressive small steps so we can have intermediate steps @@ -789,7 +784,7 @@ spirv::AccessChainOp mlir::spirv::getElementPtr( SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, - ValueRange indices, Location loc, OpBuilder &builder) { + ValueRange indices, Location loc, OpBuilder &builder, Type indexType) { // Get base and offset of the MemRefType and verify they are static. int64_t offset; @@ -800,8 +795,6 @@ return nullptr; } - auto indexType = typeConverter.getIndexType(builder.getContext()); - SmallVector linearizedIndices; auto zero = spirv::ConstantOp::getZero(indexType, loc, builder); @@ -812,7 +805,7 @@ linearizedIndices.push_back(zero); } else { linearizedIndices.push_back( - linearizeIndex(indices, strides, offset, loc, builder)); + linearizeIndex(indices, strides, offset, loc, builder, indexType)); } return builder.create(loc, basePtr, linearizedIndices); }