diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp --- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp @@ -723,12 +723,6 @@ //===----------------------------------------------------------------------===// static Type getElementPtrType(Type type, ValueRange indices, Location baseLoc) { - if (indices.empty()) { - emitError(baseLoc, "'spv.AccessChain' op expected at least " - "one index "); - return nullptr; - } - auto ptrType = type.dyn_cast(); if (!ptrType) { emitError(baseLoc, "'spv.AccessChain' op expected a pointer " @@ -791,19 +785,37 @@ OpAsmParser::OperandType ptrInfo; SmallVector indicesInfo; Type type; - // TODO(denis0x0D): regarding to the spec an index must be any integer type, - // figure out how to use resolveOperand with a range of types and do not - // fail on first attempt. - Type indicesType = parser.getBuilder().getIntegerType(32); + auto loc = parser.getCurrentLocation(); + SmallVector indicesTypes; if (parser.parseOperand(ptrInfo) || parser.parseOperandList(indicesInfo, OpAsmParser::Delimiter::Square) || parser.parseColonType(type) || - parser.resolveOperand(ptrInfo, type, state.operands) || - parser.resolveOperands(indicesInfo, indicesType, state.operands)) { + parser.resolveOperand(ptrInfo, type, state.operands)) { return failure(); } + // Check that the provided indices list is not empty before parsing their + // type list. + if (indicesInfo.empty()) { + return emitError(state.location, "'spv.AccessChain' op expected at " + "least one index "); + } + + if (parser.parseComma() || parser.parseTypeList(indicesTypes)) + return failure(); + + // Check that the indices types list is not empty and that it has a one-to-one + // mapping to the provided indices. + if (indicesTypes.size() != indicesInfo.size()) { + return emitError(state.location, "'spv.AccessChain' op indices " + "types' count must be equal to indices " + "info count"); + } + + if (parser.resolveOperands(indicesInfo, indicesTypes, loc, state.operands)) + return failure(); + auto resultType = getElementPtrType( type, llvm::makeArrayRef(state.operands).drop_front(), state.location); if (!resultType) { @@ -816,7 +828,8 @@ static void print(spirv::AccessChainOp op, OpAsmPrinter &printer) { printer << spirv::AccessChainOp::getOperationName() << ' ' << op.base_ptr() - << '[' << op.indices() << "] : " << op.base_ptr().getType(); + << '[' << op.indices() << "] : " << op.base_ptr().getType() << ", " + << op.indices().getTypes(); } static LogicalResult verify(spirv::AccessChainOp accessChainOp) { diff --git a/mlir/test/Conversion/GPUToVulkan/lower-gpu-launch-vulkan-launch.mlir b/mlir/test/Conversion/GPUToVulkan/lower-gpu-launch-vulkan-launch.mlir --- a/mlir/test/Conversion/GPUToVulkan/lower-gpu-launch-vulkan-launch.mlir +++ b/mlir/test/Conversion/GPUToVulkan/lower-gpu-launch-vulkan-launch.mlir @@ -11,7 +11,7 @@ %0 = spv._address_of @kernel_arg_0 : !spv.ptr [0]>, StorageBuffer> %2 = spv.constant 0 : i32 %3 = spv._address_of @kernel_arg_0 : !spv.ptr [0]>, StorageBuffer> - %4 = spv.AccessChain %0[%2, %2] : !spv.ptr [0]>, StorageBuffer> + %4 = spv.AccessChain %0[%2, %2] : !spv.ptr [0]>, StorageBuffer>, i32, i32 %5 = spv.Load "StorageBuffer" %4 : f32 spv.Return } diff --git a/mlir/test/Dialect/SPIRV/Serialization/array.mlir b/mlir/test/Dialect/SPIRV/Serialization/array.mlir --- a/mlir/test/Dialect/SPIRV/Serialization/array.mlir +++ b/mlir/test/Dialect/SPIRV/Serialization/array.mlir @@ -2,8 +2,8 @@ spv.module Logical GLSL450 requires #spv.vce { spv.func @array_stride(%arg0 : !spv.ptr, stride=128>, StorageBuffer>, %arg1 : i32, %arg2 : i32) "None" { - // CHECK: {{%.*}} = spv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spv.ptr, stride=128>, StorageBuffer> - %2 = spv.AccessChain %arg0[%arg1, %arg2] : !spv.ptr, stride=128>, StorageBuffer> + // CHECK: {{%.*}} = spv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spv.ptr, stride=128>, StorageBuffer>, i32, i32 + %2 = spv.AccessChain %arg0[%arg1, %arg2] : !spv.ptr, stride=128>, StorageBuffer>, i32, i32 spv.Return } } diff --git a/mlir/test/Dialect/SPIRV/Serialization/cooperative-matrix.mlir b/mlir/test/Dialect/SPIRV/Serialization/cooperative-matrix.mlir --- a/mlir/test/Dialect/SPIRV/Serialization/cooperative-matrix.mlir +++ b/mlir/test/Dialect/SPIRV/Serialization/cooperative-matrix.mlir @@ -95,8 +95,8 @@ // CHECK-LABEL: @cooperative_matrix_access_chain spv.func @cooperative_matrix_access_chain(%a : !spv.ptr, Function>) -> !spv.ptr "None" { %0 = spv.constant 0: i32 - // CHECK: {{%.*}} = spv.AccessChain {{%.*}}[{{%.*}}] : !spv.ptr, Function> - %1 = spv.AccessChain %a[%0] : !spv.ptr, Function> + // CHECK: {{%.*}} = spv.AccessChain {{%.*}}[{{%.*}}] : !spv.ptr, Function>, i32 + %1 = spv.AccessChain %a[%0] : !spv.ptr, Function>, i32 spv.ReturnValue %1 : !spv.ptr } } diff --git a/mlir/test/Dialect/SPIRV/Serialization/debug.mlir b/mlir/test/Dialect/SPIRV/Serialization/debug.mlir --- a/mlir/test/Dialect/SPIRV/Serialization/debug.mlir +++ b/mlir/test/Dialect/SPIRV/Serialization/debug.mlir @@ -58,7 +58,7 @@ spv.func @memory_accesses(%arg0 : !spv.ptr>, StorageBuffer>, %arg1 : i32, %arg2 : i32) "None" { // CHECK: loc({{".*debug.mlir"}}:61:10) - %2 = spv.AccessChain %arg0[%arg1, %arg2] : !spv.ptr>, StorageBuffer> + %2 = spv.AccessChain %arg0[%arg1, %arg2] : !spv.ptr>, StorageBuffer>, i32, i32 // CHECK: loc({{".*debug.mlir"}}:63:10) %3 = spv.Load "StorageBuffer" %2 : f32 // CHECK: loc({{.*debug.mlir"}}:65:5) diff --git a/mlir/test/Dialect/SPIRV/Serialization/global-variable.mlir b/mlir/test/Dialect/SPIRV/Serialization/global-variable.mlir --- a/mlir/test/Dialect/SPIRV/Serialization/global-variable.mlir +++ b/mlir/test/Dialect/SPIRV/Serialization/global-variable.mlir @@ -30,7 +30,7 @@ %0 = spv._address_of @globalInvocationID : !spv.ptr, Input> %1 = spv.constant 0: i32 // CHECK: spv.AccessChain %[[ADDR]] - %2 = spv.AccessChain %0[%1] : !spv.ptr, Input> + %2 = spv.AccessChain %0[%1] : !spv.ptr, Input>, i32 spv.Return } } diff --git a/mlir/test/Dialect/SPIRV/Serialization/loop.mlir b/mlir/test/Dialect/SPIRV/Serialization/loop.mlir --- a/mlir/test/Dialect/SPIRV/Serialization/loop.mlir +++ b/mlir/test/Dialect/SPIRV/Serialization/loop.mlir @@ -65,9 +65,9 @@ spv.func @loop_kernel() "None" { %0 = spv._address_of @GV1 : !spv.ptr [0]>, StorageBuffer> %1 = spv.constant 0 : i32 - %2 = spv.AccessChain %0[%1] : !spv.ptr [0]>, StorageBuffer> + %2 = spv.AccessChain %0[%1] : !spv.ptr [0]>, StorageBuffer>, i32 %3 = spv._address_of @GV2 : !spv.ptr [0]>, StorageBuffer> - %5 = spv.AccessChain %3[%1] : !spv.ptr [0]>, StorageBuffer> + %5 = spv.AccessChain %3[%1] : !spv.ptr [0]>, StorageBuffer>, i32 %6 = spv.constant 4 : i32 %7 = spv.constant 42 : i32 %8 = spv.constant 2 : i32 @@ -84,9 +84,9 @@ spv.BranchConditional %10, ^body, ^merge // CHECK-NEXT: ^bb2: // pred: ^bb1 ^body: - %11 = spv.AccessChain %2[%9] : !spv.ptr, StorageBuffer> + %11 = spv.AccessChain %2[%9] : !spv.ptr, StorageBuffer>, i32 %12 = spv.Load "StorageBuffer" %11 : f32 - %13 = spv.AccessChain %5[%9] : !spv.ptr, StorageBuffer> + %13 = spv.AccessChain %5[%9] : !spv.ptr, StorageBuffer>, i32 spv.Store "StorageBuffer" %13, %12 : f32 // CHECK: %[[ADD:.*]] = spv.IAdd %14 = spv.IAdd %9, %8 : i32 diff --git a/mlir/test/Dialect/SPIRV/Serialization/matrix.mlir b/mlir/test/Dialect/SPIRV/Serialization/matrix.mlir --- a/mlir/test/Dialect/SPIRV/Serialization/matrix.mlir +++ b/mlir/test/Dialect/SPIRV/Serialization/matrix.mlir @@ -4,7 +4,7 @@ // CHECK-LABEL: @matrix_access_chain spv.func @matrix_access_chain(%arg0 : !spv.ptr>, Function>, %arg1 : i32) -> !spv.ptr, Function> "None" { // CHECK: {{%.*}} = spv.AccessChain {{%.*}}[{{%.*}}] : !spv.ptr>, Function> - %0 = spv.AccessChain %arg0[%arg1] : !spv.ptr>, Function> + %0 = spv.AccessChain %arg0[%arg1] : !spv.ptr>,Function>, i32 spv.ReturnValue %0 : !spv.ptr, Function> } @@ -20,6 +20,7 @@ // CHECK: {{%.*}} = spv.MatrixTimesScalar {{%.*}}, {{%.*}} : !spv.matrix<3 x vector<3xf16>>, f16 -> !spv.matrix<3 x vector<3xf16>> %result = spv.MatrixTimesScalar %arg0, %arg1 : !spv.matrix<3 x vector<3xf16>>, f16 -> !spv.matrix<3 x vector<3xf16>> spv.ReturnValue %result : !spv.matrix<3 x vector<3xf16>> + } } diff --git a/mlir/test/Dialect/SPIRV/Serialization/memory-ops.mlir b/mlir/test/Dialect/SPIRV/Serialization/memory-ops.mlir --- a/mlir/test/Dialect/SPIRV/Serialization/memory-ops.mlir +++ b/mlir/test/Dialect/SPIRV/Serialization/memory-ops.mlir @@ -18,8 +18,8 @@ spv.func @access_chain(%arg0 : !spv.ptr>, Function>, %arg1 : i32, %arg2 : i32) "None" { // CHECK: {{%.*}} = spv.AccessChain {{%.*}}[{{%.*}}] : !spv.ptr>, Function> // CHECK-NEXT: {{%.*}} = spv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spv.ptr>, Function> - %1 = spv.AccessChain %arg0[%arg1] : !spv.ptr>, Function> - %2 = spv.AccessChain %arg0[%arg1, %arg2] : !spv.ptr>, Function> + %1 = spv.AccessChain %arg0[%arg1] : !spv.ptr>, Function>, i32 + %2 = spv.AccessChain %arg0[%arg1, %arg2] : !spv.ptr>, Function>, i32, i32 spv.Return } } @@ -31,13 +31,13 @@ // CHECK: [[LOAD_PTR:%.*]] = spv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spv.ptr [0]> // CHECK-NEXT: [[VAL:%.*]] = spv.Load "StorageBuffer" [[LOAD_PTR]] : f32 %0 = spv.constant 0 : i32 - %1 = spv.AccessChain %arg0[%0, %0] : !spv.ptr [0]>, StorageBuffer> + %1 = spv.AccessChain %arg0[%0, %0] : !spv.ptr [0]>, StorageBuffer>, i32, i32 %2 = spv.Load "StorageBuffer" %1 : f32 // CHECK: [[STORE_PTR:%.*]] = spv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spv.ptr [0]> // CHECK-NEXT: spv.Store "StorageBuffer" [[STORE_PTR]], [[VAL]] : f32 %3 = spv.constant 0 : i32 - %4 = spv.AccessChain %arg1[%3, %3] : !spv.ptr [0]>, StorageBuffer> + %4 = spv.AccessChain %arg1[%3, %3] : !spv.ptr [0]>, StorageBuffer>, i32, i32 spv.Store "StorageBuffer" %4, %2 : f32 spv.Return } @@ -46,13 +46,13 @@ // CHECK: [[LOAD_PTR:%.*]] = spv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spv.ptr [0]> // CHECK-NEXT: [[VAL:%.*]] = spv.Load "StorageBuffer" [[LOAD_PTR]] : i32 %0 = spv.constant 0 : i32 - %1 = spv.AccessChain %arg0[%0, %0] : !spv.ptr [0]>, StorageBuffer> + %1 = spv.AccessChain %arg0[%0, %0] : !spv.ptr [0]>, StorageBuffer>, i32, i32 %2 = spv.Load "StorageBuffer" %1 : i32 // CHECK: [[STORE_PTR:%.*]] = spv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spv.ptr [0]> // CHECK-NEXT: spv.Store "StorageBuffer" [[STORE_PTR]], [[VAL]] : i32 %3 = spv.constant 0 : i32 - %4 = spv.AccessChain %arg1[%3, %3] : !spv.ptr [0]>, StorageBuffer> + %4 = spv.AccessChain %arg1[%3, %3] : !spv.ptr [0]>, StorageBuffer>, i32, i32 spv.Store "StorageBuffer" %4, %2 : i32 spv.Return } diff --git a/mlir/test/Dialect/SPIRV/Serialization/undef.mlir b/mlir/test/Dialect/SPIRV/Serialization/undef.mlir --- a/mlir/test/Dialect/SPIRV/Serialization/undef.mlir +++ b/mlir/test/Dialect/SPIRV/Serialization/undef.mlir @@ -16,7 +16,7 @@ // CHECK: {{%.*}} = spv.undef : !spv.ptr, StorageBuffer> %7 = spv.undef : !spv.ptr, StorageBuffer> %8 = spv.constant 0 : i32 - %9 = spv.AccessChain %7[%8] : !spv.ptr, StorageBuffer> + %9 = spv.AccessChain %7[%8] : !spv.ptr, StorageBuffer>, i32 spv.Return } } diff --git a/mlir/test/Dialect/SPIRV/Transforms/abi-load-store.mlir b/mlir/test/Dialect/SPIRV/Transforms/abi-load-store.mlir --- a/mlir/test/Dialect/SPIRV/Transforms/abi-load-store.mlir +++ b/mlir/test/Dialect/SPIRV/Transforms/abi-load-store.mlir @@ -102,14 +102,14 @@ %37 = spv.IAdd %arg4, %11 : i32 // CHECK: spv.AccessChain [[ARG0]] %c0 = spv.constant 0 : i32 - %38 = spv.AccessChain %arg0[%c0, %36, %37] : !spv.ptr>>, StorageBuffer> + %38 = spv.AccessChain %arg0[%c0, %36, %37] : !spv.ptr>>, StorageBuffer>, i32, i32, i32 %39 = spv.Load "StorageBuffer" %38 : f32 // CHECK: spv.AccessChain [[ARG1]] - %40 = spv.AccessChain %arg1[%c0, %36, %37] : !spv.ptr>>, StorageBuffer> + %40 = spv.AccessChain %arg1[%c0, %36, %37] : !spv.ptr>>, StorageBuffer>, i32, i32, i32 %41 = spv.Load "StorageBuffer" %40 : f32 %42 = spv.FAdd %39, %41 : f32 // CHECK: spv.AccessChain [[ARG2]] - %43 = spv.AccessChain %arg2[%c0, %36, %37] : !spv.ptr>>, StorageBuffer> + %43 = spv.AccessChain %arg2[%c0, %36, %37] : !spv.ptr>>, StorageBuffer>, i32, i32, i32 spv.Store "StorageBuffer" %43, %42 : f32 spv.Return } diff --git a/mlir/test/Dialect/SPIRV/Transforms/inlining.mlir b/mlir/test/Dialect/SPIRV/Transforms/inlining.mlir --- a/mlir/test/Dialect/SPIRV/Transforms/inlining.mlir +++ b/mlir/test/Dialect/SPIRV/Transforms/inlining.mlir @@ -37,7 +37,7 @@ spv.func @callee() "None" { %0 = spv._address_of @data : !spv.ptr [0]>, StorageBuffer> %1 = spv.constant 0: i32 - %2 = spv.AccessChain %0[%1, %1] : !spv.ptr [0]>, StorageBuffer> + %2 = spv.AccessChain %0[%1, %1] : !spv.ptr [0]>, StorageBuffer>, i32, i32 spv.Branch ^next ^next: @@ -196,7 +196,7 @@ // CHECK: [[VAL:%.*]] = spv.Load "StorageBuffer" [[LOADPTR]] %2 = spv._address_of @arg_0 : !spv.ptr, StorageBuffer> %3 = spv._address_of @arg_1 : !spv.ptr, StorageBuffer> - %4 = spv.AccessChain %2[%1] : !spv.ptr, StorageBuffer> + %4 = spv.AccessChain %2[%1] : !spv.ptr, StorageBuffer>, i32 %5 = spv.Load "StorageBuffer" %4 : i32 %6 = spv.SGreaterThan %5, %1 : i32 // CHECK: spv.selection @@ -204,7 +204,7 @@ spv.BranchConditional %6, ^bb1, ^bb2 ^bb1: // pred: ^bb0 // CHECK: [[STOREPTR:%.*]] = spv.AccessChain [[ADDRESS_ARG1]] - %7 = spv.AccessChain %3[%1] : !spv.ptr, StorageBuffer> + %7 = spv.AccessChain %3[%1] : !spv.ptr, StorageBuffer>, i32 // CHECK-NOT: spv.FunctionCall // CHECK: spv.AtomicIAdd "Device" "AcquireRelease" [[STOREPTR]], [[VAL]] // CHECK: spv.Branch diff --git a/mlir/test/Dialect/SPIRV/Transforms/layout-decoration.mlir b/mlir/test/Dialect/SPIRV/Transforms/layout-decoration.mlir --- a/mlir/test/Dialect/SPIRV/Transforms/layout-decoration.mlir +++ b/mlir/test/Dialect/SPIRV/Transforms/layout-decoration.mlir @@ -24,7 +24,7 @@ // CHECK: {{%.*}} = spv._address_of @var0 : !spv.ptr [4], f32 [12]>, Uniform> %0 = spv._address_of @var0 : !spv.ptr, f32>, Uniform> // CHECK: {{%.*}} = spv.AccessChain {{%.*}}[{{%.*}}] : !spv.ptr [4], f32 [12]>, Uniform> - %1 = spv.AccessChain %0[%c0] : !spv.ptr, f32>, Uniform> + %1 = spv.AccessChain %0[%c0] : !spv.ptr, f32>, Uniform>, i32 spv.Return } } diff --git a/mlir/test/Dialect/SPIRV/canonicalize.mlir b/mlir/test/Dialect/SPIRV/canonicalize.mlir --- a/mlir/test/Dialect/SPIRV/canonicalize.mlir +++ b/mlir/test/Dialect/SPIRV/canonicalize.mlir @@ -11,8 +11,8 @@ // CHECK-NEXT: spv.Load "Function" %[[PTR]] %c0 = spv.constant 0: i32 %0 = spv.Variable : !spv.ptr>, !spv.array<4xi32>>, Function> - %1 = spv.AccessChain %0[%c0] : !spv.ptr>, !spv.array<4xi32>>, Function> - %2 = spv.AccessChain %1[%c0, %c0] : !spv.ptr>, Function> + %1 = spv.AccessChain %0[%c0] : !spv.ptr>, !spv.array<4xi32>>, Function>, i32 + %2 = spv.AccessChain %1[%c0, %c0] : !spv.ptr>, Function>, i32, i32 %3 = spv.Load "Function" %2 : f32 spv.ReturnValue %3 : f32 } @@ -28,9 +28,9 @@ // CHECK-NEXT: spv.Load "Function" %[[PTR_1]] %c0 = spv.constant 0: i32 %0 = spv.Variable : !spv.ptr>, !spv.array<4xi32>>, Function> - %1 = spv.AccessChain %0[%c0] : !spv.ptr>, !spv.array<4xi32>>, Function> - %2 = spv.AccessChain %1[%c0] : !spv.ptr>, Function> - %3 = spv.AccessChain %2[%c0] : !spv.ptr, Function> + %1 = spv.AccessChain %0[%c0] : !spv.ptr>, !spv.array<4xi32>>, Function>, i32 + %2 = spv.AccessChain %1[%c0] : !spv.ptr>, Function>, i32 + %3 = spv.AccessChain %2[%c0] : !spv.ptr, Function>, i32 %4 = spv.Load "Function" %2 : !spv.array<4xf32> %5 = spv.Load "Function" %3 : f32 spv.ReturnValue %4: !spv.array<4xf32> @@ -49,8 +49,8 @@ %c1 = spv.constant 1: i32 %0 = spv.Variable : !spv.ptr>, !spv.array<4xi32>>, Function> %1 = spv.Variable : !spv.ptr>, !spv.array<4xi32>>, Function> - %2 = spv.AccessChain %0[%c1] : !spv.ptr>, !spv.array<4xi32>>, Function> - %3 = spv.AccessChain %1[%c1] : !spv.ptr>, !spv.array<4xi32>>, Function> + %2 = spv.AccessChain %0[%c1] : !spv.ptr>, !spv.array<4xi32>>, Function>, i32 + %3 = spv.AccessChain %1[%c1] : !spv.ptr>, !spv.array<4xi32>>, Function>, i32 %4 = spv.Load "Function" %2 : !spv.array<4xi32> %5 = spv.Load "Function" %3 : !spv.array<4xi32> spv.ReturnValue %4 : !spv.array<4xi32> diff --git a/mlir/test/Dialect/SPIRV/cooperative-matrix.mlir b/mlir/test/Dialect/SPIRV/cooperative-matrix.mlir --- a/mlir/test/Dialect/SPIRV/cooperative-matrix.mlir +++ b/mlir/test/Dialect/SPIRV/cooperative-matrix.mlir @@ -97,8 +97,8 @@ // CHECK-LABEL: @cooperative_matrix_access_chain spv.func @cooperative_matrix_access_chain(%a : !spv.ptr, Function>) -> !spv.ptr "None" { %0 = spv.constant 0: i32 - // CHECK: {{%.*}} = spv.AccessChain {{%.*}}[{{%.*}}] : !spv.ptr, Function> - %1 = spv.AccessChain %a[%0] : !spv.ptr, Function> + // CHECK: {{%.*}} = spv.AccessChain {{%.*}}[{{%.*}}] : !spv.ptr, Function>, i32 + %1 = spv.AccessChain %a[%0] : !spv.ptr, Function>, i32 spv.ReturnValue %1 : !spv.ptr } diff --git a/mlir/test/Dialect/SPIRV/ops.mlir b/mlir/test/Dialect/SPIRV/ops.mlir --- a/mlir/test/Dialect/SPIRV/ops.mlir +++ b/mlir/test/Dialect/SPIRV/ops.mlir @@ -8,21 +8,21 @@ %0 = spv.constant 1: i32 %1 = spv.Variable : !spv.ptr>, Function> // CHECK: spv.AccessChain {{.*}}[{{.*}}, {{.*}}] : !spv.ptr>, Function> - %2 = spv.AccessChain %1[%0, %0] : !spv.ptr>, Function> + %2 = spv.AccessChain %1[%0, %0] : !spv.ptr>, Function>, i32, i32 return } func @access_chain_1D_array(%arg0 : i32) -> () { %0 = spv.Variable : !spv.ptr, Function> // CHECK: spv.AccessChain {{.*}}[{{.*}}] : !spv.ptr, Function> - %1 = spv.AccessChain %0[%arg0] : !spv.ptr, Function> + %1 = spv.AccessChain %0[%arg0] : !spv.ptr, Function>, i32 return } func @access_chain_2D_array_1(%arg0 : i32) -> () { %0 = spv.Variable : !spv.ptr>, Function> // CHECK: spv.AccessChain {{.*}}[{{.*}}, {{.*}}] : !spv.ptr>, Function> - %1 = spv.AccessChain %0[%arg0, %arg0] : !spv.ptr>, Function> + %1 = spv.AccessChain %0[%arg0, %arg0] : !spv.ptr>, Function>, i32, i32 %2 = spv.Load "Function" %1 ["Volatile"] : f32 return } @@ -30,7 +30,7 @@ func @access_chain_2D_array_2(%arg0 : i32) -> () { %0 = spv.Variable : !spv.ptr>, Function> // CHECK: spv.AccessChain {{.*}}[{{.*}}] : !spv.ptr>, Function> - %1 = spv.AccessChain %0[%arg0] : !spv.ptr>, Function> + %1 = spv.AccessChain %0[%arg0] : !spv.ptr>, Function>, i32 %2 = spv.Load "Function" %1 ["Volatile"] : !spv.array<4xf32> return } @@ -38,7 +38,7 @@ func @access_chain_rtarray(%arg0 : i32) -> () { %0 = spv.Variable : !spv.ptr, Function> // CHECK: spv.AccessChain {{.*}}[{{.*}}] : !spv.ptr, Function> - %1 = spv.AccessChain %0[%arg0] : !spv.ptr, Function> + %1 = spv.AccessChain %0[%arg0] : !spv.ptr, Function>, i32 %2 = spv.Load "Function" %1 ["Volatile"] : f32 return } @@ -49,7 +49,7 @@ %0 = spv.constant 1: i32 %1 = spv.Variable : !spv.ptr // expected-error @+1 {{cannot extract from non-composite type 'f32' with index 0}} - %2 = spv.AccessChain %1[%0] : !spv.ptr + %2 = spv.AccessChain %1[%0] : !spv.ptr, i32 return } @@ -58,7 +58,34 @@ func @access_chain_no_indices(%index0 : i32) -> () { %0 = spv.Variable : !spv.ptr>, Function> // expected-error @+1 {{expected at least one index}} - %1 = spv.AccessChain %0[] : !spv.ptr>, Function> + %1 = spv.AccessChain %0[] : !spv.ptr>, Function>, i32 + return +} + +// ----- + +func @access_chain_missing_comma(%index0 : i32) -> () { + %0 = spv.Variable : !spv.ptr>, Function> + // expected-error @+1 {{expected ','}} + %1 = spv.AccessChain %0[%index0] : !spv.ptr>, Function> i32 + return +} + +// ----- + +func @access_chain_invalid_indices_types_count(%index0 : i32) -> () { + %0 = spv.Variable : !spv.ptr>, Function> + // expected-error @+1 {{'spv.AccessChain' op indices types' count must be equal to indices info count}} + %1 = spv.AccessChain %0[%index0] : !spv.ptr>, Function>, i32, i32 + return +} + +// ----- + +func @access_chain_missing_indices_type(%index0 : i32) -> () { + %0 = spv.Variable : !spv.ptr>, Function> + // expected-error @+1 {{'spv.AccessChain' op indices types' count must be equal to indices info count}} + %1 = spv.AccessChain %0[%index0, %index0] : !spv.ptr>, Function>, i32 return } @@ -68,7 +95,7 @@ %0 = spv.Variable : !spv.ptr>, Function> %1 = spv.Load "Function" %0 ["Volatile"] : !spv.array<4x!spv.array<4xf32>> // expected-error @+1 {{expected a pointer to composite type, but provided '!spv.array<4 x !spv.array<4 x f32>>'}} - %2 = spv.AccessChain %1[%index0] : !spv.array<4x!spv.array<4xf32>> + %2 = spv.AccessChain %1[%index0] : !spv.array<4x!spv.array<4xf32>>, i32 return } @@ -77,7 +104,7 @@ func @access_chain_invalid_index_1(%index0 : i32) -> () { %0 = spv.Variable : !spv.ptr>, Function> // expected-error @+1 {{expected SSA operand}} - %1 = spv.AccessChain %0[%index, 4] : !spv.ptr>, Function> + %1 = spv.AccessChain %0[%index, 4] : !spv.ptr>, Function>, i32, i32 return } @@ -86,7 +113,7 @@ func @access_chain_invalid_index_2(%index0 : i32) -> () { %0 = spv.Variable : !spv.ptr>, Function> // expected-error @+1 {{index must be an integer spv.constant to access element of spv.struct}} - %1 = spv.AccessChain %0[%index0, %index0] : !spv.ptr>, Function> + %1 = spv.AccessChain %0[%index0, %index0] : !spv.ptr>, Function>, i32, i32 return } @@ -96,7 +123,7 @@ %0 = std.constant 1: i32 %1 = spv.Variable : !spv.ptr>, Function> // expected-error @+1 {{index must be an integer spv.constant to access element of spv.struct, but provided std.constant}} - %2 = spv.AccessChain %1[%0, %0] : !spv.ptr>, Function> + %2 = spv.AccessChain %1[%0, %0] : !spv.ptr>, Function>, i32, i32 return } @@ -106,7 +133,7 @@ %index0 = "spv.constant"() { value = 12: i32} : () -> i32 %0 = spv.Variable : !spv.ptr>, Function> // expected-error @+1 {{'spv.AccessChain' op index 12 out of bounds for '!spv.struct>'}} - %1 = spv.AccessChain %0[%index0, %index0] : !spv.ptr>, Function> + %1 = spv.AccessChain %0[%index0, %index0] : !spv.ptr>, Function>, i32, i32 return } @@ -115,7 +142,7 @@ func @access_chain_invalid_accessing_type(%index0 : i32) -> () { %0 = spv.Variable : !spv.ptr>, Function> // expected-error @+1 {{cannot extract from non-composite type 'f32' with index 0}} - %1 = spv.AccessChain %0[%index, %index0, %index0] : !spv.ptr>, Function> + %1 = spv.AccessChain %0[%index, %index0, %index0] : !spv.ptr>, Function>, i32, i32, i32 return // ----- diff --git a/mlir/test/Dialect/SPIRV/structure-ops.mlir b/mlir/test/Dialect/SPIRV/structure-ops.mlir --- a/mlir/test/Dialect/SPIRV/structure-ops.mlir +++ b/mlir/test/Dialect/SPIRV/structure-ops.mlir @@ -11,7 +11,7 @@ // CHECK: [[VAR1:%.*]] = spv._address_of @var1 : !spv.ptr>, Input> // CHECK-NEXT: spv.AccessChain [[VAR1]][{{.*}}, {{.*}}] : !spv.ptr>, Input> %1 = spv._address_of @var1 : !spv.ptr>, Input> - %2 = spv.AccessChain %1[%0, %0] : !spv.ptr>, Input> + %2 = spv.AccessChain %1[%0, %0] : !spv.ptr>, Input>, i32, i32 spv.Return } }