diff --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h --- a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h +++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h @@ -49,20 +49,25 @@ /// values will be packed into one 32-bit value to be memory efficient. bool emulateNon32BitScalarTypes; + /// Use 64-bit integers (instead of 32-bit) to convert index types. + bool use64bitIndex; + /// The number of bits to store a boolean value. It is eight bits by /// default. unsigned boolNumBits; - // Note: we need this instead of inline initializers becuase of + // Note: we need this instead of inline initializers because of // https://bugs.llvm.org/show_bug.cgi?id=36684 - Options() : emulateNon32BitScalarTypes(true), boolNumBits(8) {} + Options() + : emulateNon32BitScalarTypes(true), use64bitIndex(false), + boolNumBits(8) {} }; explicit SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr, Options options = {}); /// Gets the SPIR-V correspondence for the standard index type. - static Type getIndexType(MLIRContext *context); + Type getIndexType() const; /// Returns the corresponding memory space for memref given a SPIR-V storage /// class. @@ -79,6 +84,8 @@ private: spirv::TargetEnv targetEnv; Options options; + + MLIRContext *getContext() const; }; //===----------------------------------------------------------------------===// @@ -129,21 +136,23 @@ /// Returns the value for the given `builtin` variable. This function gets or /// inserts the global variable associated for the builtin within the nearest /// symbol table enclosing `op`. Returns null Value on error. -Value getBuiltinVariableValue(Operation *op, BuiltIn builtin, +Value getBuiltinVariableValue(Operation *op, BuiltIn builtin, Type integerType, OpBuilder &builder); /// Gets the value at the given `offset` of the push constant storage with a -/// total of `elementCount` 32-bit integers.
A global variable will be created -/// in the nearest symbol table enclosing `op` for the push constant storage if -/// not existing. Load ops will be created via the given `builder` to load -/// values from the push constant. Returns null Value on error. +/// total of `elementCount` `integerType` integers. A global variable will be +/// created in the nearest symbol table enclosing `op` for the push constant +/// storage if not existing. Load ops will be created via the given `builder` to +/// load values from the push constant. Returns null Value on error. Value getPushConstantValue(Operation *op, unsigned elementCount, - unsigned offset, OpBuilder &builder); + unsigned offset, Type integerType, + OpBuilder &builder); /// Generates IR to perform index linearization with the given `indices` and /// their corresponding `strides`, adding an initial `offset`. Value linearizeIndex(ValueRange indices, ArrayRef strides, - int64_t offset, Location loc, OpBuilder &builder); + int64_t offset, Type integerType, Location loc, + OpBuilder &builder); /// Performs the index computation to get to the element at `indices` of the /// memory pointed to by `basePtr`, using the layout map of `baseType`.
diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp --- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp +++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp @@ -135,10 +135,14 @@ if (!index) return failure(); + auto *typeConverter = this->template getTypeConverter(); + auto indexType = typeConverter->getIndexType(); + - // SPIR-V invocation builtin variables are a vector of type <3xi32> + // Builtin vector elements use indexType; NOTE(review): Vulkan requires i32. - auto spirvBuiltin = spirv::getBuiltinVariableValue(op, builtin, rewriter); + auto spirvBuiltin = + spirv::getBuiltinVariableValue(op, builtin, indexType, rewriter); rewriter.replaceOpWithNewOp( - op, rewriter.getIntegerType(32), spirvBuiltin, + op, indexType, spirvBuiltin, rewriter.getI32ArrayAttr({index.getValue()})); return success(); } @@ -148,7 +152,11 @@ SingleDimLaunchConfigConversion::matchAndRewrite( SourceOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const { - auto spirvBuiltin = spirv::getBuiltinVariableValue(op, builtin, rewriter); + auto *typeConverter = this->template getTypeConverter(); + auto indexType = typeConverter->getIndexType(); + + auto spirvBuiltin = + spirv::getBuiltinVariableValue(op, builtin, indexType, rewriter); rewriter.replaceOp(op, spirvBuiltin); return success(); } diff --git a/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp b/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp --- a/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp +++ b/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp @@ -26,11 +26,11 @@ -/// Returns a `Value` containing the `dim`-th dimension's size of SPIR-V -/// location invocation ID. This function will create necessary operations with +/// Returns a `Value` containing the `dim`-th dimension of the SPIR-V +/// local invocation ID. This function will create necessary operations with /// `builder` at the proper region containing `op`.
-static Value getLocalInvocationDimSize(Operation *op, int dim, Location loc, - OpBuilder *builder) { +static Value getLocalInvocationDimSize(Operation *op, int dim, Type integerType, + Location loc, OpBuilder *builder) { assert(dim >= 0 && dim < 3 && "local invocation only has three dimensions"); Value invocation = spirv::getBuiltinVariableValue( - op, spirv::BuiltIn::LocalInvocationId, *builder); + op, spirv::BuiltIn::LocalInvocationId, integerType, *builder); Type xType = invocation.getType().cast().getElementType(); return builder->create( loc, xType, invocation, builder->getI32ArrayAttr({dim})); @@ -137,12 +137,15 @@ Value convertedInput = operands[0], convertedOutput = operands[1]; Location loc = genericOp.getLoc(); + auto *typeConverter = getTypeConverter(); + auto indexType = typeConverter->getIndexType(); + // Get the invocation ID. - Value x = getLocalInvocationDimSize(genericOp, /*dim=*/0, loc, &rewriter); + Value x = getLocalInvocationDimSize(genericOp, /*dim=*/0, indexType, loc, + &rewriter); // TODO: Load to Workgroup storage class first. - auto *typeConverter = getTypeConverter(); // Get the input element accessed by this invocation. Value inputElementPtr = spirv::getElementPtr( @@ -164,8 +167,7 @@ #undef CREATE_GROUP_NON_UNIFORM_BIN_OP // Get the output element accessed by this reduction.
- Value zero = spirv::ConstantOp::getZero( - typeConverter->getIndexType(rewriter.getContext()), loc, rewriter); + Value zero = spirv::ConstantOp::getZero(indexType, loc, rewriter); SmallVector zeroIndices(originalOutputType.getRank(), zero); Value outputElementPtr = spirv::getElementPtr(*typeConverter, originalOutputType, convertedOutput, diff --git a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp --- a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp @@ -373,8 +373,11 @@ return failure(); } + auto &typeConverter = *getTypeConverter(); + auto indexType = typeConverter.getIndexType(); + Value index = spirv::linearizeIndex(adaptor.indices(), strides, - /*offset=*/0, loc, rewriter); + /*offset=*/0, indexType, loc, rewriter); auto acOp = rewriter.create(loc, varOp, index); rewriter.replaceOpWithNewOp(extractOp, acOp); diff --git a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp --- a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp @@ -178,6 +178,9 @@ TypeConverter::SignatureConversion signatureConverter( funcOp.getType().getNumInputs()); + auto &typeConverter = *getTypeConverter(); + auto indexType = typeConverter.getIndexType(); + auto attrName = spirv::getInterfaceVarABIAttrName(); for (auto argType : llvm::enumerate(funcOp.getType().getInputs())) { auto abiInfo = funcOp.getArgAttrOfType( @@ -206,7 +209,6 @@ // before the use. There might be multiple loads and currently there is no // easy way to replace all uses with a sequence of operations.
if (argType.value().cast().isScalarOrVector()) { - auto indexType = SPIRVTypeConverter::getIndexType(funcOp.getContext()); auto zero = spirv::ConstantOp::getZero(indexType, funcOp.getLoc(), rewriter); auto loadPtr = rewriter.create( diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp --- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp @@ -112,15 +112,8 @@ // Type Conversion //===----------------------------------------------------------------------===// -Type SPIRVTypeConverter::getIndexType(MLIRContext *context) { - // Convert to 32-bit integers for now. Might need a way to control this in - // future. - // TODO: It is probably better to make it 64-bit integers. To - // this some support is needed in SPIR-V dialect for Conversion - // instructions. The Vulkan spec requires the builtins like - // GlobalInvocationID, etc. to be 32-bit (unsigned) integers which should be - // SExtended to 64-bit for index computations. - return IntegerType::get(context, 32); +Type SPIRVTypeConverter::getIndexType() const { + return IntegerType::get(getContext(), options.use64bitIndex ? 64 : 32); } /// Mapping between SPIR-V storage classes to memref memory spaces. @@ -183,6 +176,10 @@ return options; } +MLIRContext *SPIRVTypeConverter::getContext() const { + return targetEnv.getAttr().getContext(); +} + #undef STORAGE_SPACE_MAP_LIST // TODO: This is a utility function that should probably be exposed by the @@ -505,9 +502,7 @@ // want to validate and convert to be safe.
addConversion([](spirv::SPIRVType type) { return type; }); - addConversion([](IndexType indexType) { - return SPIRVTypeConverter::getIndexType(indexType.getContext()); - }); + addConversion([this](IndexType /*indexType*/) { return getIndexType(); }); addConversion([this](IntegerType intType) -> Optional { if (auto scalarType = intType.dyn_cast()) @@ -630,7 +625,7 @@ -/// Gets or inserts a global variable for a builtin within `body` block. +/// Gets or inserts a builtin in `body`; `integerType` is ignored if it exists. static spirv::GlobalVariableOp getOrInsertBuiltinVariable(Block &body, Location loc, spirv::BuiltIn builtin, - OpBuilder &builder) { + Type integerType, OpBuilder &builder) { if (auto varOp = getBuiltinVariable(body, builtin)) return varOp; @@ -644,9 +639,8 @@ case spirv::BuiltIn::WorkgroupId: case spirv::BuiltIn::LocalInvocationId: case spirv::BuiltIn::GlobalInvocationId: { - auto ptrType = spirv::PointerType::get( - VectorType::get({3}, builder.getIntegerType(32)), - spirv::StorageClass::Input); + auto ptrType = spirv::PointerType::get(VectorType::get({3}, integerType), + spirv::StorageClass::Input); std::string name = getBuiltinVarName(builtin); newVarOp = builder.create(loc, ptrType, name, builtin); @@ -655,8 +649,8 @@ case spirv::BuiltIn::SubgroupId: case spirv::BuiltIn::NumSubgroups: case spirv::BuiltIn::SubgroupSize: { - auto ptrType = spirv::PointerType::get(builder.getIntegerType(32), - spirv::StorageClass::Input); + auto ptrType = + spirv::PointerType::get(integerType, spirv::StorageClass::Input); std::string name = getBuiltinVarName(builtin); newVarOp = builder.create(loc, ptrType, name, builtin); @@ -671,6 +665,7 @@ Value mlir::spirv::getBuiltinVariableValue(Operation *op, spirv::BuiltIn builtin, + Type integerType, OpBuilder &builder) { Operation *parent = SymbolTable::getNearestSymbolTable(op->getParentOp()); if (!parent) { @@ -678,8 +673,9 @@ return nullptr; } - spirv::GlobalVariableOp varOp = getOrInsertBuiltinVariable( - *parent->getRegion(0).begin(), op->getLoc(), builtin, builder); + spirv::GlobalVariableOp varOp
= + getOrInsertBuiltinVariable(*parent->getRegion(0).begin(), op->getLoc(), + builtin, integerType, builder); Value ptr = builder.create(op->getLoc(), varOp); return builder.create(op->getLoc(), ptr); } @@ -691,10 +687,10 @@ /// Returns the pointer type for the push constant storage containing -/// `elementCount` 32-bit integer values. +/// `elementCount` `indexType` integer values. static spirv::PointerType getPushConstantStorageType(unsigned elementCount, - Builder &builder) { - auto arrayType = spirv::ArrayType::get( - SPIRVTypeConverter::getIndexType(builder.getContext()), elementCount, - /*stride=*/4); + Builder &builder, + Type indexType) { + unsigned stride = indexType.getIntOrFloatBitWidth() / 8; // Bytes per element. + auto arrayType = spirv::ArrayType::get(indexType, elementCount, stride); auto structType = spirv::StructType::get({arrayType}, /*offsetInfo=*/0); return spirv::PointerType::get(structType, spirv::StorageClass::PushConstant); } @@ -725,19 +721,21 @@ -/// `elementCount` 32-bit integer values in `block`. +/// `elementCount` `indexType` integer values in `block`. static spirv::GlobalVariableOp getOrInsertPushConstantVariable(Location loc, Block &block, - unsigned elementCount, OpBuilder &b) { + unsigned elementCount, OpBuilder &b, + Type indexType) { if (auto varOp = getPushConstantVariable(block, elementCount)) return varOp; auto builder = OpBuilder::atBlockBegin(&block, b.getListener()); - auto type = getPushConstantStorageType(elementCount, builder); + auto type = getPushConstantStorageType(elementCount, builder, indexType); const char *name = "__push_constant_var__"; return builder.create(loc, type, name, /*initializer=*/nullptr); } Value spirv::getPushConstantValue(Operation *op, unsigned elementCount, - unsigned offset, OpBuilder &builder) { + unsigned offset, Type integerType, + OpBuilder &builder) { Location loc = op->getLoc(); Operation *parent = SymbolTable::getNearestSymbolTable(op->getParentOp()); if (!parent) { @@ -746,12 +744,11 @@ } spirv::GlobalVariableOp varOp = getOrInsertPushConstantVariable( - loc, parent->getRegion(0).front(), elementCount,
builder, integerType); - auto i32Type = SPIRVTypeConverter::getIndexType(builder.getContext()); - Value zeroOp = spirv::ConstantOp::getZero(i32Type, loc, builder); + Value zeroOp = spirv::ConstantOp::getZero(integerType, loc, builder); Value offsetOp = builder.create( - loc, i32Type, builder.getI32IntegerAttr(offset)); + loc, integerType, builder.getIntegerAttr(integerType, offset)); auto addrOp = builder.create(loc, varOp); auto acOp = builder.create( loc, addrOp, llvm::makeArrayRef({zeroOp, offsetOp})); @@ -763,23 +760,22 @@ //===----------------------------------------------------------------------===// Value mlir::spirv::linearizeIndex(ValueRange indices, ArrayRef strides, - int64_t offset, Location loc, - OpBuilder &builder) { + int64_t offset, Type integerType, + Location loc, OpBuilder &builder) { assert(indices.size() == strides.size() && "must provide indices for all dimensions"); - auto indexType = SPIRVTypeConverter::getIndexType(builder.getContext()); - // TODO: Consider moving to use affine.apply and patterns converting // affine.apply to standard ops. This needs converting to SPIR-V passes to be // broken down into progressive small steps so we can have intermediate steps // using other dialects. At the moment SPIR-V is the final sink.
Value linearizedIndex = builder.create( - loc, indexType, IntegerAttr::get(indexType, offset)); + loc, integerType, IntegerAttr::get(integerType, offset)); for (auto index : llvm::enumerate(indices)) { Value strideVal = builder.create( - loc, indexType, IntegerAttr::get(indexType, strides[index.index()])); + loc, integerType, + IntegerAttr::get(integerType, strides[index.index()])); Value update = builder.create(loc, strideVal, index.value()); linearizedIndex = builder.create(loc, linearizedIndex, update); @@ -800,7 +796,7 @@ return nullptr; } - auto indexType = typeConverter.getIndexType(builder.getContext()); + auto indexType = typeConverter.getIndexType(); SmallVector linearizedIndices; auto zero = spirv::ConstantOp::getZero(indexType, loc, builder); @@ -812,7 +808,7 @@ linearizedIndices.push_back(zero); } else { linearizedIndices.push_back( - linearizeIndex(indices, strides, offset, loc, builder)); + linearizeIndex(indices, strides, offset, indexType, loc, builder)); } return builder.create(loc, basePtr, linearizedIndices); }