diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td @@ -3194,6 +3194,8 @@ def SPV_OC_OpStore : I32EnumAttrCase<"OpStore", 62>; def SPV_OC_OpCopyMemory : I32EnumAttrCase<"OpCopyMemory", 63>; def SPV_OC_OpAccessChain : I32EnumAttrCase<"OpAccessChain", 65>; +def SPV_OC_OpPtrAccessChain : I32EnumAttrCase<"OpPtrAccessChain", 67>; +def SPV_OC_OpInBoundsPtrAccessChain : I32EnumAttrCase<"OpInBoundsPtrAccessChain", 70>; def SPV_OC_OpDecorate : I32EnumAttrCase<"OpDecorate", 71>; def SPV_OC_OpMemberDecorate : I32EnumAttrCase<"OpMemberDecorate", 72>; def SPV_OC_OpVectorExtractDynamic : I32EnumAttrCase<"OpVectorExtractDynamic", 77>; @@ -3340,10 +3342,10 @@ SPV_OC_OpSpecConstant, SPV_OC_OpSpecConstantComposite, SPV_OC_OpSpecConstantOp, SPV_OC_OpFunction, SPV_OC_OpFunctionParameter, SPV_OC_OpFunctionEnd, SPV_OC_OpFunctionCall, SPV_OC_OpVariable, SPV_OC_OpLoad, SPV_OC_OpStore, - SPV_OC_OpCopyMemory, SPV_OC_OpAccessChain, SPV_OC_OpDecorate, - SPV_OC_OpMemberDecorate, SPV_OC_OpVectorExtractDynamic, - SPV_OC_OpVectorInsertDynamic, SPV_OC_OpVectorShuffle, - SPV_OC_OpCompositeConstruct, SPV_OC_OpCompositeExtract, + SPV_OC_OpCopyMemory, SPV_OC_OpAccessChain, SPV_OC_OpPtrAccessChain, + SPV_OC_OpInBoundsPtrAccessChain, SPV_OC_OpDecorate, SPV_OC_OpMemberDecorate, + SPV_OC_OpVectorExtractDynamic, SPV_OC_OpVectorInsertDynamic, + SPV_OC_OpVectorShuffle, SPV_OC_OpCompositeConstruct, SPV_OC_OpCompositeExtract, SPV_OC_OpCompositeInsert, SPV_OC_OpTranspose, SPV_OC_OpImageDrefGather, SPV_OC_OpImage, SPV_OC_OpImageQuerySize, SPV_OC_OpConvertFToU, SPV_OC_OpConvertFToS, SPV_OC_OpConvertSToF, SPV_OC_OpConvertUToF, diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td +++ 
b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td @@ -137,6 +137,55 @@ // ----- +def SPV_InBoundsPtrAccessChainOp : SPV_Op<"InBoundsPtrAccessChain", [NoSideEffect]> { + let summary = [{ + Has the same semantics as OpPtrAccessChain, with the addition that the + resulting pointer is known to point within the base object. + }]; + + let description = [{ + + + + + ``` + access-chain-op ::= ssa-id `=` `spv.InBoundsPtrAccessChain` ssa-use + `[` ssa-use (',' ssa-use)* `]` + `:` pointer-type + ```mlir + + #### Example: + + ``` + func @inbounds_ptr_access_chain(%arg0: !spv.ptr, %arg1 : i64) -> () { + %0 = spv.InBoundsPtrAccessChain %arg0[%arg1] : !spv.ptr, i64 + ... + } + ``` + }]; + + let availability = [ + MinVersion, + MaxVersion, + Extension<[]>, + Capability<[SPV_C_Addresses]> + ]; + + let arguments = (ins + SPV_AnyPtr:$base_ptr, + SPV_Integer:$element, + Variadic:$indices + ); + + let results = (outs + SPV_AnyPtr:$result + ); + + let builders = [OpBuilder<(ins "Value":$basePtr, "Value":$element, "ValueRange":$indices)>]; +} + +// ----- + def SPV_LoadOp : SPV_Op<"Load", []> { let summary = "Load through a pointer."; @@ -191,6 +240,78 @@ // ----- +def SPV_PtrAccessChainOp : SPV_Op<"PtrAccessChain", [NoSideEffect]> { + let summary = [{ + Has the same semantics as OpAccessChain, with the addition of the + Element operand. + }]; + + let description = [{ + Element is used to do an initial dereference of Base: Base is treated as + the address of an element in an array, and a new element address is + computed from Base and Element to become the OpAccessChain Base to + dereference as per OpAccessChain. This computed Base has the same type + as the originating Base. + + To compute the new element address, Element is treated as a signed count + of elements E, relative to the original Base element B, and the address + of element B + E is computed using enough precision to avoid overflow + and underflow. 
For objects in the Uniform, StorageBuffer, or + PushConstant storage classes, the element's address or location is + calculated using a stride, which will be the Base-type's Array Stride if + the Base type is decorated with ArrayStride. For all other objects, the + implementation calculates the element's address or location. + + With one exception, undefined behavior results when B + E is not an + element in the same array (same innermost array, if array types are + nested) as B. The exception being when B + E = L, where L is the length + of the array: the address computation for element L is done with the + same stride as any other B + E computation that stays within the array. + + Note: If Base is typed to be a pointer to an array and the desired + operation is to select an element of that array, OpAccessChain should be + directly used, as its first Index selects the array element. + + + + ``` + access-chain-op ::= ssa-id `=` `spv.PtrAccessChain` ssa-use + `[` ssa-use (',' ssa-use)* `]` + `:` pointer-type + ``` + + #### Example: + + ```mlir + func @ptr_access_chain(%arg0: !spv.ptr, %arg1 : i64) -> () { + %0 = spv.PtrAccessChain %arg0[%arg1] : !spv.ptr, i64 + ...
+ } + ``` + }]; + + let availability = [ + MinVersion, + MaxVersion, + Extension<[]>, + Capability<[SPV_C_Addresses, SPV_C_PhysicalStorageBufferAddresses, SPV_C_VariablePointers, SPV_C_VariablePointersStorageBuffer]> + ]; + + let arguments = (ins + SPV_AnyPtr:$base_ptr, + SPV_Integer:$element, + Variadic:$indices + ); + + let results = (outs + SPV_AnyPtr:$result + ); + + let builders = [OpBuilder<(ins "Value":$basePtr, "Value":$element, "ValueRange":$indices)>]; +} + +// ----- + def SPV_StoreOp : SPV_Op<"Store", []> { let summary = "Store through a pointer."; diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -1019,37 +1019,41 @@ return success(); } +template +static void printAccessChain(Op op, ValueRange indices, OpAsmPrinter &printer) { + printer << Op::getOperationName() << ' ' << op.base_ptr() << '[' << indices + << "] : " << op.base_ptr().getType() << ", " << indices.getTypes(); +} + static void print(spirv::AccessChainOp op, OpAsmPrinter &printer) { - printer << spirv::AccessChainOp::getOperationName() << ' ' << op.base_ptr() - << '[' << op.indices() << "] : " << op.base_ptr().getType() << ", " - << op.indices().getTypes(); + printAccessChain(op, op.indices(), printer); } -static LogicalResult verify(spirv::AccessChainOp accessChainOp) { - SmallVector indices(accessChainOp.indices().begin(), - accessChainOp.indices().end()); +template +static LogicalResult verifyAccessChain(Op accessChainOp, ValueRange indices) { auto resultType = getElementPtrType(accessChainOp.base_ptr().getType(), indices, accessChainOp.getLoc()); - if (!resultType) { + if (!resultType) return failure(); - } auto providedResultType = accessChainOp.getType().dyn_cast(); - if (!providedResultType) { + if (!providedResultType) return accessChainOp.emitOpError( "result type must be a pointer, but provided") << providedResultType; - } - if (resultType != 
providedResultType) { + if (resultType != providedResultType) return accessChainOp.emitOpError("invalid result type: expected ") << resultType << ", but provided " << providedResultType; - } return success(); } +static LogicalResult verify(spirv::AccessChainOp accessChainOp) { + return verifyAccessChain(accessChainOp, accessChainOp.indices()); +} + //===----------------------------------------------------------------------===// // spv.mlir.addressof //===----------------------------------------------------------------------===// @@ -3770,6 +3774,109 @@ return success(); } +static ParseResult parsePtrAccessChainOpImpl(StringRef opName, + OpAsmParser &parser, + OperationState &state) { + OpAsmParser::OperandType ptrInfo; + SmallVector indicesInfo; + Type type; + auto loc = parser.getCurrentLocation(); + SmallVector indicesTypes; + + if (parser.parseOperand(ptrInfo) || + parser.parseOperandList(indicesInfo, OpAsmParser::Delimiter::Square) || + parser.parseColonType(type) || + parser.resolveOperand(ptrInfo, type, state.operands)) + return failure(); + + // Check that the provided indices list is not empty before parsing their + // type list. + if (indicesInfo.empty()) + return emitError(state.location) << opName << " expected element"; + + if (parser.parseComma() || parser.parseTypeList(indicesTypes)) + return failure(); + + // Check that the indices types list is not empty and that it has a one-to-one + // mapping to the provided indices. 
+ if (indicesTypes.size() != indicesInfo.size()) + return emitError(state.location) + << opName + << " indices types' count must be equal to indices info count"; + + if (parser.resolveOperands(indicesInfo, indicesTypes, loc, state.operands)) + return failure(); + + auto resultType = getElementPtrType( + type, llvm::makeArrayRef(state.operands).drop_front(2), state.location); + if (!resultType) + return failure(); + + state.addTypes(resultType); + return success(); +} + +template +static auto concatElemAndIndices(Op op) { + SmallVector ret(op.indices().size() + 1); + ret[0] = op.element(); + llvm::copy(op.indices(), ret.begin() + 1); + return ret; +} + +//===----------------------------------------------------------------------===// +// spv.InBoundsPtrAccessChainOp +//===----------------------------------------------------------------------===// + +void spirv::InBoundsPtrAccessChainOp::build(OpBuilder &builder, + OperationState &state, + Value basePtr, Value element, + ValueRange indices) { + auto type = getElementPtrType(basePtr.getType(), indices, state.location); + assert(type && "Unable to deduce return type based on basePtr and indices"); + build(builder, state, type, basePtr, element, indices); +} + +static ParseResult parseInBoundsPtrAccessChainOp(OpAsmParser &parser, + OperationState &state) { + return parsePtrAccessChainOpImpl( + spirv::InBoundsPtrAccessChainOp::getOperationName(), parser, state); +} + +static void print(spirv::InBoundsPtrAccessChainOp op, OpAsmPrinter &printer) { + printAccessChain(op, concatElemAndIndices(op), printer); +} + +static LogicalResult verify(spirv::InBoundsPtrAccessChainOp accessChainOp) { + return verifyAccessChain(accessChainOp, accessChainOp.indices()); +} + +//===----------------------------------------------------------------------===// +// spv.PtrAccessChainOp +//===----------------------------------------------------------------------===// + +void spirv::PtrAccessChainOp::build(OpBuilder &builder, OperationState &state, 
+ Value basePtr, Value element, + ValueRange indices) { + auto type = getElementPtrType(basePtr.getType(), indices, state.location); + assert(type && "Unable to deduce return type based on basePtr and indices"); + build(builder, state, type, basePtr, element, indices); +} + +static ParseResult parsePtrAccessChainOp(OpAsmParser &parser, + OperationState &state) { + return parsePtrAccessChainOpImpl(spirv::PtrAccessChainOp::getOperationName(), + parser, state); +} + +static void print(spirv::PtrAccessChainOp op, OpAsmPrinter &printer) { + printAccessChain(op, concatElemAndIndices(op), printer); +} + +static LogicalResult verify(spirv::PtrAccessChainOp accessChainOp) { + return verifyAccessChain(accessChainOp, accessChainOp.indices()); +} + namespace mlir { namespace spirv { diff --git a/mlir/test/Dialect/SPIRV/IR/memory-ops.mlir b/mlir/test/Dialect/SPIRV/IR/memory-ops.mlir --- a/mlir/test/Dialect/SPIRV/IR/memory-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/memory-ops.mlir @@ -628,3 +628,33 @@ spv.Return } + +// ----- + +//===----------------------------------------------------------------------===// +// spv.PtrAccessChain +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: func @ptr_access_chain1( +// CHECK-SAME: %[[ARG0:.*]]: !spv.ptr, +// CHECK-SAME: %[[ARG1:.*]]: i64) +// CHECK: spv.PtrAccessChain %[[ARG0]][%[[ARG1]]] : !spv.ptr, i64 +func @ptr_access_chain1(%arg0: !spv.ptr, %arg1 : i64) -> () { + %0 = spv.PtrAccessChain %arg0[%arg1] : !spv.ptr, i64 + return +} + +// ----- + +//===----------------------------------------------------------------------===// +// spv.InBoundsPtrAccessChain +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: func @inbounds_ptr_access_chain1( +// CHECK-SAME: %[[ARG0:.*]]: !spv.ptr, +// CHECK-SAME: %[[ARG1:.*]]: i64) +// CHECK: spv.InBoundsPtrAccessChain %[[ARG0]][%[[ARG1]]] : !spv.ptr, i64 +func @inbounds_ptr_access_chain1(%arg0: !spv.ptr, 
%arg1 : i64) -> () { + %0 = spv.InBoundsPtrAccessChain %arg0[%arg1] : !spv.ptr, i64 + return +}