diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
@@ -126,7 +126,26 @@
 //===----------------------------------------------------------------------===//
 
 OpFoldResult spirv::CompositeExtractOp::fold(ArrayRef<Attribute> operands) {
-  assert(operands.size() == 1 && "spv.CompositeExtract expects one operand");
+  // extract(insert(object, composite, indices), indices) -> object: the
+  // extracted element is exactly the value that was just inserted.
+  if (auto insertOp = composite().getDefiningOp<spirv::CompositeInsertOp>()) {
+    if (indices() == insertOp.indices())
+      return insertOp.object();
+  }
+
+  // extract(construct(c0, ..., cN), i) -> ci. Only handled when there is a
+  // single index and one constituent per element of the result type, so the
+  // index maps directly onto a constituent.
+  if (auto constructOp =
+          composite().getDefiningOp<spirv::CompositeConstructOp>()) {
+    auto type = constructOp.getType().cast<spirv::CompositeType>();
+    if (indices().size() == 1 &&
+        constructOp.constituents().size() == type.getNumElements()) {
+      auto i = indices().begin()->cast<IntegerAttr>();
+      return constructOp.constituents()[i.getValue().getSExtValue()];
+    }
+  }
+
   auto indexVector =
       llvm::to_vector<8>(llvm::map_range(indices(), [](Attribute attr) {
         return static_cast<unsigned>(attr.cast<IntegerAttr>().getInt());
diff --git a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
--- a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
@@ -137,6 +137,53 @@
 
 // -----
 
+// An extract whose indices match the insert folds to the inserted value; an
+// extract with different indices is left alone.
+
+// CHECK-LABEL: extract_insert
+// CHECK-SAME: (%[[COMP:.+]]: !spv.array<1 x vector<2xf32>>, %[[VAL:.+]]: f32)
+func.func @extract_insert(%composite: !spv.array<1xvector<2xf32>>, %val: f32) -> (f32, f32) {
+  // CHECK: %[[INSERT:.+]] = spv.CompositeInsert %[[VAL]], %[[COMP]]
+  %insert = spv.CompositeInsert %val, %composite[0 : i32, 1 : i32] : f32 into !spv.array<1xvector<2xf32>>
+  %1 = spv.CompositeExtract %insert[0 : i32, 0 : i32] : !spv.array<1xvector<2xf32>>
+  // CHECK: %[[S:.+]] = spv.CompositeExtract %[[INSERT]][0 : i32, 0 : i32]
+  %2 = spv.CompositeExtract %insert[0 : i32, 1 : i32] : !spv.array<1xvector<2xf32>>
+  // CHECK: return %[[S]], %[[VAL]]
+  return %1, %2 : f32, f32
+}
+
+// -----
+
+// With one constituent per element, each extract folds to its constituent.
+
+// CHECK-LABEL: extract_construct
+// CHECK-SAME: (%[[VAL1:.+]]: vector<2xf32>, %[[VAL2:.+]]: vector<2xf32>)
+func.func @extract_construct(%val1: vector<2xf32>, %val2: vector<2xf32>) -> (vector<2xf32>, vector<2xf32>) {
+  %construct = spv.CompositeConstruct %val1, %val2 : (vector<2xf32>, vector<2xf32>) -> !spv.array<2xvector<2xf32>>
+  %1 = spv.CompositeExtract %construct[0 : i32] : !spv.array<2xvector<2xf32>>
+  %2 = spv.CompositeExtract %construct[1 : i32] : !spv.array<2xvector<2xf32>>
+  // CHECK: return %[[VAL1]], %[[VAL2]]
+  return %1, %2 : vector<2xf32>, vector<2xf32>
+}
+
+// -----
+
+// Folding is not yet implemented when a constituent is itself a vector that
+// contributes several elements to the result; both ops must be left in place.
+
+// CHECK-LABEL: extract_construct_not_folded
+func.func @extract_construct_not_folded(%val1: vector<3xf32>, %val2: f32) -> (f32, f32) {
+  // CHECK: spv.CompositeConstruct
+  %construct = spv.CompositeConstruct %val1, %val2 : (vector<3xf32>, f32) -> vector<4xf32>
+  // CHECK: spv.CompositeExtract
+  %1 = spv.CompositeExtract %construct[0 : i32] : vector<4xf32>
+  // CHECK: spv.CompositeExtract
+  %2 = spv.CompositeExtract %construct[1 : i32] : vector<4xf32>
+  return %1, %2 : f32, f32
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spv.Constant
 //===----------------------------------------------------------------------===//